micro_sam.util
Helper functions for downloading Segment Anything models and predicting image embeddings.
1""" 2Helper functions for downloading Segment Anything models and predicting image embeddings. 3""" 4 5import os 6import pickle 7import hashlib 8import warnings 9from pathlib import Path 10from collections import OrderedDict 11from typing import Any, Dict, Iterable, Optional, Tuple, Union, Callable 12 13import zarr 14import vigra 15import torch 16import pooch 17import xxhash 18import numpy as np 19import imageio.v3 as imageio 20from skimage.measure import regionprops 21from skimage.segmentation import relabel_sequential 22 23from elf.io import open_file 24 25from nifty.tools import blocking 26 27from .__version__ import __version__ 28from . import models as custom_models 29 30try: 31 # Avoid import warnigns from mobile_sam 32 with warnings.catch_warnings(): 33 warnings.simplefilter("ignore") 34 from mobile_sam import sam_model_registry, SamPredictor 35 VIT_T_SUPPORT = True 36except ImportError: 37 from segment_anything import sam_model_registry, SamPredictor 38 VIT_T_SUPPORT = False 39 40try: 41 from napari.utils import progress as tqdm 42except ImportError: 43 from tqdm import tqdm 44 45# This is the default model used in micro_sam 46# Currently it is set to vit_b_lm 47_DEFAULT_MODEL = "vit_b_lm" 48 49# The valid model types. Each type corresponds to the architecture of the 50# vision transformer used within SAM. 51_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t") 52 53 54# TODO define the proper type for image embeddings 55ImageEmbeddings = Dict[str, Any] 56"""@private""" 57 58 59def get_cache_directory() -> None: 60 """Get micro-sam cache directory location. 61 62 Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. 63 """ 64 default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) 65 cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) 66 return cache_directory 67 68 69# 70# Functionality for model download and export 71# 72 73 74def microsam_cachedir() -> None: 75 """Return the micro-sam cache directory. 76 77 Returns the top level cache directory for micro-sam models and sample data. 78 79 Every time this function is called, we check for any user updates made to 80 the MICROSAM_CACHEDIR os environment variable since the last time. 81 """ 82 cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") 83 return cache_directory 84 85 86def models(): 87 """Return the segmentation models registry. 88 89 We recreate the model registry every time this function is called, 90 so any user changes to the default micro-sam cache directory location 91 are respected. 92 """ 93 94 # We use xxhash to compute the hash of the models, see 95 # https://github.com/computational-cell-analytics/micro-sam/issues/283 96 # (It is now a dependency, so we don't provide the sha256 fallback anymore.) 97 # To generate the xxh128 hash: 98 # xxh128sum filename 99 encoder_registry = { 100 # The default segment anything models: 101 "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", 102 "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", 103 "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", 104 # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM. 105 "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", 106 # The current version of our models in the modelzoo. 
107 # LM generalist models: 108 "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d", 109 "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186", 110 "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e", 111 # EM models: 112 "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d", 113 "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438", 114 "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9", 115 # Histopathology models: 116 "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974", 117 "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2", 118 "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e", 119 # Medical Imaging models: 120 "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1", 121 } 122 # Additional decoders for instance segmentation. 123 decoder_registry = { 124 # LM generalist models: 125 "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a", 126 "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba", 127 "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823", 128 # EM models: 129 "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294", 130 "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2", 131 "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44", 132 # Histopathology models: 133 "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213", 134 "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04", 135 "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47", 136 } 137 registry = {**encoder_registry, **decoder_registry} 138 139 encoder_urls = { 140 "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 141 "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 142 "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 143 "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", 144 "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt", 145 "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt", 146 "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt", 147 "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt", # noqa 148 "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt", # noqa 149 "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa 150 "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download", 151 "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download", 152 "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download", 153 "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download", 154 } 155 156 decoder_urls = { 157 "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt", # noqa 158 "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt", # noqa 159 "vit_t_lm_decoder": 
"https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt", # noqa 160 "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt", # noqa 161 "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt", # noqa 162 "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa 163 "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download", 164 "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download", 165 "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download", 166 } 167 urls = {**encoder_urls, **decoder_urls} 168 169 models = pooch.create( 170 path=os.path.join(microsam_cachedir(), "models"), 171 base_url="", 172 registry=registry, 173 urls=urls, 174 ) 175 return models 176 177 178def _get_default_device(): 179 # check that we're in CI and use the CPU if we are 180 # otherwise the tests may run out of memory on MAC if MPS is used. 181 if os.getenv("GITHUB_ACTIONS") == "true": 182 return "cpu" 183 # Use cuda enabled gpu if it's available. 184 if torch.cuda.is_available(): 185 device = "cuda" 186 # As second priority use mps. 187 # See https://pytorch.org/docs/stable/notes/mps.html for details 188 elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): 189 print("Using apple MPS device.") 190 device = "mps" 191 # Use the CPU as fallback. 192 else: 193 device = "cpu" 194 return device 195 196 197def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: 198 """Get the torch device. 199 200 If no device is passed the default device for your system is used. 201 Else it will be checked if the device you have passed is supported. 202 203 Args: 204 device: The input device. By default, selects the best available device supports. 205 206 Returns: 207 The device. 208 """ 209 if device is None or device == "auto": 210 device = _get_default_device() 211 else: 212 device_type = device if isinstance(device, str) else device.type 213 if device_type.lower() == "cuda": 214 if not torch.cuda.is_available(): 215 raise RuntimeError("PyTorch CUDA backend is not available.") 216 elif device_type.lower() == "mps": 217 if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): 218 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") 219 elif device_type.lower() == "cpu": 220 pass # cpu is always available 221 else: 222 raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.") 223 224 return device 225 226 227def _available_devices(): 228 available_devices = [] 229 for i in ["cuda", "mps", "cpu"]: 230 try: 231 device = get_device(i) 232 except RuntimeError: 233 pass 234 else: 235 available_devices.append(device) 236 return available_devices 237 238 239# We write a custom unpickler that skips objects that cannot be found instead of 240# throwing an AttributeError or ModueNotFoundError. 241# NOTE: since we just want to unpickle the model to load its weights these errors don't matter. 
242# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules 243class _CustomUnpickler(pickle.Unpickler): 244 def find_class(self, module, name): 245 try: 246 return super().find_class(module, name) 247 except (AttributeError, ModuleNotFoundError) as e: 248 warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}") 249 return None 250 251 252def _compute_hash(path, chunk_size=8192): 253 hash_obj = xxhash.xxh128() 254 with open(path, "rb") as f: 255 chunk = f.read(chunk_size) 256 while chunk: 257 hash_obj.update(chunk) 258 chunk = f.read(chunk_size) 259 hash_val = hash_obj.hexdigest() 260 return f"xxh128:{hash_val}" 261 262 263# Load the state from a checkpoint. 264# The checkpoint can either contain a sam encoder state 265# or it can be a checkpoint for model finetuning. 266def _load_checkpoint(checkpoint_path): 267 # Over-ride the unpickler with our custom one. 268 # This enables imports from torch_em checkpoints even if it cannot be fully unpickled. 269 custom_pickle = pickle 270 custom_pickle.Unpickler = _CustomUnpickler 271 272 state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle) 273 if "model_state" in state: 274 # Copy the model weights from torch_em's training format. 275 model_state = state["model_state"] 276 sam_prefix = "sam." 277 model_state = OrderedDict( 278 [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] 279 ) 280 else: 281 model_state = state 282 283 return state, model_state 284 285 286def _download_sam_model(model_type, progress_bar_factory=None): 287 model_registry = models() 288 289 progress_bar = True 290 # Check if we have to download the model. 291 # If we do and have a progress bar factory, then we over-write the progress bar. 292 if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None: 293 progress_bar = progress_bar_factory(model_type) 294 295 checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar) 296 if not isinstance(progress_bar, bool): # Close the progress bar when the task finishes. 297 progress_bar.close() 298 299 model_hash = model_registry.registry[model_type] 300 301 # If we have a custom model then we may also have a decoder checkpoint. 302 # Download it here, so that we can add it to the state. 303 decoder_name = f"{model_type}_decoder" 304 decoder_path = model_registry.fetch( 305 decoder_name, progressbar=True 306 ) if decoder_name in model_registry.registry else None 307 308 return checkpoint_path, model_hash, decoder_path 309 310 311def get_sam_model( 312 model_type: str = _DEFAULT_MODEL, 313 device: Optional[Union[str, torch.device]] = None, 314 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 315 return_sam: bool = False, 316 return_state: bool = False, 317 peft_kwargs: Optional[Dict] = None, 318 flexible_load_checkpoint: bool = False, 319 progress_bar_factory: Optional[Callable] = None, 320 **model_kwargs, 321) -> SamPredictor: 322 r"""Get the Segment Anything Predictor. 323 324 This function will download the required model or load it from the cached weight file. 325 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 326 The name of the requested model can be set via `model_type`. 
327 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 328 for an overview of the available models 329 330 Alternatively this function can also load a model from weights stored in a local filepath. 331 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 332 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 333 a SAM model with vit_b encoder. 334 335 By default the models are downloaded to a folder named 'micro_sam/models' 336 inside your default cache directory, eg: 337 * Mac: ~/Library/Caches/<AppName> 338 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 339 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 340 See the pooch.os_cache() documentation for more details: 341 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 342 343 Args: 344 model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. 345 To get a list of all available model names you can call `micro_sam.util.get_model_names`. 346 device: The device for the model. If 'None' is provided, will use GPU if available. 347 checkpoint_path: The path to a file with weights that should be used instead of using the 348 weights corresponding to `model_type`. If given, `model_type` must match the architecture 349 corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder 350 then `model_type` must be given as 'vit_b'. 351 return_sam: Return the sam model object as well as the predictor. By default, set to 'False'. 352 return_state: Return the unpickled checkpoint state. By default, set to 'False'. 353 peft_kwargs: Keyword arguments for th PEFT wrapper class. 354 If passed 'None', it does not initialize any parameter efficient finetuning. 355 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 356 By default, set to 'False'. 357 progress_bar_factory: A function to create a progress bar for the model download. 358 model_kwargs: Additional parameters necessary to initialize the Segment Anything model. 359 360 Returns: 361 The Segment Anything predictor. 362 """ 363 device = get_device(device) 364 365 # We support passing a local filepath to a checkpoint. 366 # In this case we do not download any weights but just use the local weight file, 367 # as it is, without copying it over anywhere or checking it's hashes. 368 369 # checkpoint_path has not been passed, we download a known model and derive the correct 370 # URL from the model_type. If the model_type is invalid pooch will raise an error. 371 _provided_checkpoint_path = checkpoint_path is not None 372 if checkpoint_path is None: 373 checkpoint_path, model_hash, decoder_path = _download_sam_model(model_type, progress_bar_factory) 374 375 # checkpoint_path has been passed, we use it instead of downloading a model. 376 else: 377 # Check if the file exists and raise an error otherwise. 378 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 379 # (If it isn't the model creation will fail below.) 380 if not os.path.exists(checkpoint_path): 381 raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.") 382 model_hash = _compute_hash(checkpoint_path) 383 decoder_path = None 384 385 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 386 # before calling sam_model_registry. 
387 abbreviated_model_type = model_type[:5] 388 if abbreviated_model_type not in _MODEL_TYPES: 389 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 390 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 391 raise RuntimeError( 392 "'mobile_sam' is required for the vit-tiny. " 393 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 394 ) 395 396 state, model_state = _load_checkpoint(checkpoint_path) 397 398 if _provided_checkpoint_path: 399 # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type' 400 # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination. 401 from micro_sam.models.build_sam import _validate_model_type 402 _provided_model_type = _validate_model_type(model_state) 403 404 # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type' 405 # Otherwise replace 'abbreviated_model_type' with the later. 406 if abbreviated_model_type != _provided_model_type: 407 # Printing the message below to avoid any filtering of warnings on user's end. 408 print( 409 f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', " 410 f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. " 411 f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. " 412 "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future." 413 ) 414 415 # Replace the extracted 'abbreviated_model_type' subjected to the model weights. 416 abbreviated_model_type = _provided_model_type 417 418 # Whether to update parameters necessary to initialize the model 419 if model_kwargs: # Checks whether model_kwargs have been provided or not 420 if abbreviated_model_type == "vit_t": 421 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 422 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 423 424 else: 425 sam = sam_model_registry[abbreviated_model_type]() 426 427 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 428 # Overwrites the SAM model by freezing the backbone and allow PEFT. 429 if peft_kwargs and isinstance(peft_kwargs, dict): 430 # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference. 431 peft_kwargs.pop("quantize", None) 432 433 if abbreviated_model_type == "vit_t": 434 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 435 436 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 437 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 438 if flexible_load_checkpoint: 439 sam = _handle_checkpoint_loading(sam, model_state) 440 else: 441 sam.load_state_dict(model_state) 442 sam.to(device=device) 443 444 predictor = SamPredictor(sam) 445 predictor.model_type = abbreviated_model_type 446 predictor._hash = model_hash 447 predictor.model_name = model_type 448 predictor.checkpoint_path = checkpoint_path 449 450 # Add the decoder to the state if we have one and if the state is returned. 
451 if decoder_path is not None and return_state: 452 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 453 454 if return_sam and return_state: 455 return predictor, sam, state 456 if return_sam: 457 return predictor, sam 458 if return_state: 459 return predictor, state 460 return predictor 461 462 463def _handle_checkpoint_loading(sam, model_state): 464 # Whether to handle the mismatch issues in a bit more elegant way. 465 # eg. while training for multi-class semantic segmentation in the mask encoder, 466 # parameters are updated - leading to "size mismatch" errors 467 468 new_state_dict = {} # for loading matching parameters 469 mismatched_layers = [] # for tracking mismatching parameters 470 471 reference_state = sam.state_dict() 472 473 for k, v in model_state.items(): 474 if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM. 475 if reference_state[k].size() == v.size(): 476 new_state_dict[k] = v 477 else: 478 mismatched_layers.append(k) 479 480 reference_state.update(new_state_dict) 481 482 if len(mismatched_layers) > 0: 483 warnings.warn(f"The layers with size mismatch: {mismatched_layers}") 484 485 for mlayer in mismatched_layers: 486 if 'weight' in mlayer: 487 torch.nn.init.kaiming_uniform_(reference_state[mlayer]) 488 elif 'bias' in mlayer: 489 reference_state[mlayer].zero_() 490 491 sam.load_state_dict(reference_state) 492 493 return sam 494 495 496def export_custom_sam_model( 497 checkpoint_path: Union[str, os.PathLike], 498 model_type: str, 499 save_path: Union[str, os.PathLike], 500 with_segmentation_decoder: bool = False, 501 prefix: str = "sam.", 502) -> None: 503 """Export a finetuned Segment Anything Model to the standard model format. 504 505 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 506 507 Args: 508 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 509 model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 510 save_path: Where to save the exported model. 511 with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. 512 If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'. 513 prefix: The prefix to remove from the model parameter keys. 514 """ 515 state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path) 516 model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]) 517 518 # Store the 'decoder_state' as well, if desired. 519 if with_segmentation_decoder: 520 if "decoder_state" not in state: 521 raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.") 522 decoder_state = state["decoder_state"] 523 save_state = {"model_state": model_state, "decoder_state": decoder_state} 524 else: 525 save_state = model_state 526 527 torch.save(save_state, save_path) 528 529 530def export_custom_qlora_model( 531 checkpoint_path: Optional[Union[str, os.PathLike]], 532 finetuned_path: Union[str, os.PathLike], 533 model_type: str, 534 save_path: Union[str, os.PathLike], 535) -> None: 536 """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format. 537 538 The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`. 
539 540 Args: 541 checkpoint_path: The path to the base foundation model from which the new model has been finetuned. 542 finetuned_path: The path to the new finetuned model, using QLoRA. 543 model_type: The Segment Anything Model type corresponding to the checkpoint. 544 save_path: Where to save the exported model. 545 """ 546 # Step 1: Get the base SAM model: used to start finetuning from. 547 _, sam = get_sam_model( 548 model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, 549 ) 550 551 # Step 2: Load the QLoRA-style finetuned model. 552 ft_state, ft_model_state = _load_checkpoint(finetuned_path) 553 554 # Step 3: Identify LoRA layers from QLoRA model. 555 # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers. 556 # - then copy the LoRA layers from the QLoRA model to the new state dict 557 updated_model_state = {} 558 559 modified_attn_layers = set() 560 modified_mlp_layers = set() 561 562 for k, v in ft_model_state.items(): 563 if "blocks." in k: 564 layer_id = int(k.split("blocks.")[1].split(".")[0]) 565 if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1: 566 modified_attn_layers.add(layer_id) 567 updated_model_state[k] = v 568 if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1: 569 modified_mlp_layers.add(layer_id) 570 updated_model_state[k] = v 571 572 # Step 4: Next, we get all the remaining parameters from the base SAM model. 573 for k, v in sam.state_dict().items(): 574 if "blocks." in k: 575 layer_id = int(k.split("blocks.")[1].split(".")[0]) 576 if k.find("attn.qkv.") != -1: 577 if layer_id in modified_attn_layers: # We have LoRA in QKV layers, so we need to modify the key 578 k = k.replace("qkv", "qkv.qkv_proj") 579 elif k.find("mlp") != -1 and k.find("image_encoder") != -1: 580 if layer_id in modified_mlp_layers: # We have LoRA in MLP layers, so we need to modify the key 581 k = k.replace("mlp.", "mlp.mlp_layer.") 582 updated_model_state[k] = v 583 584 # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff) 585 ft_state['model_state'] = updated_model_state 586 587 # Step 6: Store the new "state" to "save_path" 588 torch.save(ft_state, save_path) 589 590 591def get_model_names() -> Iterable: 592 model_registry = models() 593 model_names = model_registry.registry.keys() 594 return model_names 595 596 597# 598# Functionality for precomputing image embeddings. 599# 600 601 602def _to_image(input_): 603 # we require the input to be uint8 604 if input_.dtype != np.dtype("uint8"): 605 # first normalize the input to [0, 1] 606 input_ = input_.astype("float32") - input_.min() 607 input_ = input_ / input_.max() 608 # then bring to [0, 255] and cast to uint8 609 input_ = (input_ * 255).astype("uint8") 610 611 if input_.ndim == 2: 612 image = np.concatenate([input_[..., None]] * 3, axis=-1) 613 elif input_.ndim == 3 and input_.shape[-1] == 3: 614 image = input_ 615 else: 616 raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.") 617 618 # explicitly return a numpy array for compatibility with torchvision 619 # because the input_ array could be something like dask array 620 return np.array(image) 621 622 623@torch.no_grad 624def _compute_embeddings_batched(predictor, batched_images): 625 predictor.reset_image() 626 batched_tensors, original_sizes, input_sizes = [], [], [] 627 628 # Apply proeprocessing to all images in the batch, and then stack them. 
629 # Note: after the transformation the images are all of the same size, 630 # so they can be stacked and processed as a batch, even if the input images were of different size. 631 for image in batched_images: 632 tensor = predictor.transform.apply_image(image) 633 tensor = torch.as_tensor(tensor, device=predictor.device) 634 tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :] 635 636 original_sizes.append(image.shape[:2]) 637 input_sizes.append(tensor.shape[-2:]) 638 639 tensor = predictor.model.preprocess(tensor) 640 batched_tensors.append(tensor) 641 642 batched_tensors = torch.cat(batched_tensors) 643 features = predictor.model.image_encoder(batched_tensors) 644 645 predictor.original_size = original_sizes[-1] 646 predictor.input_size = input_sizes[-1] 647 predictor.features = features[-1] 648 predictor.is_image_set = True 649 650 return features, original_sizes, input_sizes 651 652 653# Wrapper of zarr.create dataset to support zarr v2 and zarr v3. 654def _create_dataset_with_data(group, name, data, chunks=None): 655 zarr_major_version = int(zarr.__version__.split(".")[0]) 656 if chunks is None: 657 chunks = data.shape 658 if zarr_major_version == 2: 659 ds = group.create_dataset( 660 name, data=data, shape=data.shape, compression="gzip", chunks=chunks 661 ) 662 elif zarr_major_version == 3: 663 ds = group.create_array( 664 name, shape=data.shape, compressors=[zarr.codecs.GzipCodec()], chunks=chunks, dtype=data.dtype, 665 ) 666 ds[:] = data 667 else: 668 raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}") 669 return ds 670 671 672def _create_dataset_without_data(group, name, shape, dtype, chunks): 673 zarr_major_version = int(zarr.__version__.split(".")[0]) 674 if zarr_major_version == 2: 675 ds = group.create_dataset( 676 name, shape=shape, dtype=dtype, compression="gzip", chunks=chunks 677 ) 678 elif zarr_major_version == 3: 679 ds = group.create_array( 680 name, shape=shape, compressors=[zarr.codecs.GzipCodec()], chunks=chunks, dtype=dtype 681 ) 682 else: 683 raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}") 684 return ds 685 686 687def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size): 688 tiling = blocking([0, 0], input_.shape[:2], tile_shape) 689 n_tiles = tiling.numberOfBlocks 690 691 features = f.require_group("features") 692 features.attrs["shape"] = input_.shape[:2] 693 features.attrs["tile_shape"] = tile_shape 694 features.attrs["halo"] = halo 695 696 pbar_init(n_tiles, "Compute Image Embeddings 2D tiled") 697 698 n_batches = int(np.ceil(n_tiles / batch_size)) 699 for batch_id in range(n_batches): 700 tile_start = batch_id * batch_size 701 tile_stop = min(tile_start + batch_size, n_tiles) 702 703 batched_images = [] 704 for tile_id in range(tile_start, tile_stop): 705 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 706 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 707 tile_input = _to_image(input_[outer_tile]) 708 batched_images.append(tile_input) 709 710 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 711 for i, tile_id in enumerate(range(tile_start, tile_stop)): 712 tile_embeddings, original_size, input_size = batched_embeddings[i], original_sizes[i], input_sizes[i] 713 # Unsqueeze the channel axis of the tile embeddings. 
714 tile_embeddings = tile_embeddings.unsqueeze(0) 715 ds = _create_dataset_with_data(features, str(tile_id), data=tile_embeddings.cpu().numpy()) 716 ds.attrs["original_size"] = original_size 717 ds.attrs["input_size"] = input_size 718 pbar_update(1) 719 720 _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None) 721 return features 722 723 724def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size): 725 assert input_.ndim == 3 726 727 shape = input_.shape[1:] 728 tiling = blocking([0, 0], shape, tile_shape) 729 n_tiles = tiling.numberOfBlocks 730 731 features = f.require_group("features") 732 features.attrs["shape"] = shape 733 features.attrs["tile_shape"] = tile_shape 734 features.attrs["halo"] = halo 735 736 n_slices = input_.shape[0] 737 pbar_init(n_tiles * n_slices, "Compute Image Embeddings 3D tiled") 738 739 # We batch across the z axis. 740 n_batches = int(np.ceil(n_slices / batch_size)) 741 742 for tile_id in range(n_tiles): 743 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 744 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 745 746 ds = None 747 for batch_id in range(n_batches): 748 z_start = batch_id * batch_size 749 z_stop = min(z_start + batch_size, n_slices) 750 751 batched_images = [] 752 for z in range(z_start, z_stop): 753 tile_input = _to_image(input_[z][outer_tile]) 754 batched_images.append(tile_input) 755 756 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 757 for i, z in enumerate(range(z_start, z_stop)): 758 tile_embeddings = batched_embeddings[i].unsqueeze(0) 759 if ds is None: 760 shape = (n_slices,) + tile_embeddings.shape 761 chunks = (1,) + tile_embeddings.shape 762 ds = _create_dataset_without_data( 763 features, str(tile_id), shape=shape, dtype="float32", chunks=chunks 764 ) 765 766 ds[z] = tile_embeddings.cpu().numpy() 767 pbar_update(1) 768 769 ds.attrs["original_size"] = original_sizes[-1] 770 ds.attrs["input_size"] = input_sizes[-1] 771 772 _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None) 773 774 return features 775 776 777def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update): 778 # Check if the embeddings are already cached. 779 if save_path is not None and "input_size" in f.attrs: 780 # In this case we load the embeddings. 781 features = f["features"][:] 782 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 783 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 784 # Also set the embeddings. 785 set_precomputed(predictor, image_embeddings) 786 return image_embeddings 787 788 pbar_init(1, "Compute Image Embeddings 2D") 789 # Otherwise we have to compute the embeddings. 790 predictor.reset_image() 791 predictor.set_image(_to_image(input_)) 792 features = predictor.get_image_embedding().cpu().numpy() 793 original_size = predictor.original_size 794 input_size = predictor.input_size 795 pbar_update(1) 796 797 # Save the embeddings if we have a save_path. 
798 if save_path is not None: 799 _create_dataset_with_data(f, "features", data=features) 800 _write_embedding_signature( 801 f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size, 802 ) 803 804 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 805 return image_embeddings 806 807 808def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size): 809 # Check if the features are already computed. 810 if "input_size" in f.attrs: 811 features = f["features"] 812 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 813 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 814 return image_embeddings 815 816 # Otherwise compute them. Note: saving happens automatically because we 817 # always write the features to zarr. If no save path is given we use an in-memory zarr. 818 features = _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 819 image_embeddings = {"features": features, "input_size": None, "original_size": None} 820 return image_embeddings 821 822 823def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size): 824 # Check if the embeddings are already fully cached. 825 if save_path is not None and "input_size" in f.attrs: 826 # In this case we load the embeddings. 827 features = f["features"] if lazy_loading else f["features"][:] 828 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 829 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 830 return image_embeddings 831 832 # Otherwise we have to compute the embeddings. 833 834 # First check if we have a save path or not and set things up accordingly. 835 if save_path is None: 836 features = [] 837 save_features = False 838 partial_features = False 839 else: 840 save_features = True 841 embed_shape = (1, 256, 64, 64) 842 shape = (input_.shape[0],) + embed_shape 843 chunks = (1,) + embed_shape 844 if "features" in f: 845 partial_features = True 846 features = f["features"] 847 if features.shape != shape or features.chunks != chunks: 848 raise RuntimeError("Invalid partial features") 849 else: 850 partial_features = False 851 features = _create_dataset_without_data(f, "features", shape=shape, chunks=chunks, dtype="float32") 852 853 # Initialize the pbar and batches. 854 n_slices = input_.shape[0] 855 pbar_init(n_slices, "Compute Image Embeddings 3D") 856 n_batches = int(np.ceil(n_slices / batch_size)) 857 858 for batch_id in range(n_batches): 859 z_start = batch_id * batch_size 860 z_stop = min(z_start + batch_size, n_slices) 861 862 batched_images, batched_z = [], [] 863 for z in range(z_start, z_stop): 864 # Skip feature computation in case of partial features in non-zero slice. 
865 if partial_features and np.count_nonzero(features[z]) != 0: 866 continue 867 tile_input = _to_image(input_[z]) 868 batched_images.append(tile_input) 869 batched_z.append(z) 870 871 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 872 873 for z, embedding in zip(batched_z, batched_embeddings): 874 embedding = embedding.unsqueeze(0) 875 if save_features: 876 features[z] = embedding.cpu().numpy() 877 else: 878 features.append(embedding.unsqueeze(0)) 879 pbar_update(1) 880 881 if save_features: 882 _write_embedding_signature( 883 f, input_, predictor, tile_shape=None, halo=None, 884 input_size=input_sizes[-1], original_size=original_sizes[-1], 885 ) 886 else: 887 # Concatenate across the z axis. 888 features = torch.cat(features).cpu().numpy() 889 890 image_embeddings = {"features": features, "input_size": input_sizes[-1], "original_size": original_sizes[-1]} 891 return image_embeddings 892 893 894def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size): 895 # Check if the features are already computed. 896 if "input_size" in f.attrs: 897 features = f["features"] 898 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 899 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 900 return image_embeddings 901 902 # Otherwise compute them. Note: saving happens automatically because we 903 # always write the features to zarr. If no save path is given we use an in-memory zarr. 904 features = _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 905 image_embeddings = {"features": features, "input_size": None, "original_size": None} 906 return image_embeddings 907 908 909def _compute_data_signature(input_): 910 data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest() 911 return data_signature 912 913 914# Create all metadata that is stored along with the embeddings. 915def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None): 916 if data_signature is None: 917 data_signature = _compute_data_signature(input_) 918 919 signature = { 920 "data_signature": data_signature, 921 "tile_shape": tile_shape if tile_shape is None else list(tile_shape), 922 "halo": halo if halo is None else list(halo), 923 "model_type": predictor.model_type, 924 "model_name": predictor.model_name, 925 "micro_sam_version": __version__, 926 "model_hash": getattr(predictor, "_hash", None), 927 } 928 return signature 929 930 931# Note: the input size and orginal size are different if embeddings are tiled or not. 932# That's why we do not include them in the main signature that is being checked 933# (_get_embedding_signature), but just add it for serialization here. 934def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size): 935 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 936 signature.update({"input_size": input_size, "original_size": original_size}) 937 for key, val in signature.items(): 938 f.attrs[key] = val 939 940 941def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): 942 # We may have an empty zarr file that was already created to save the embeddings in. 943 # In this case the embeddings will be computed and we don't need to perform any checks. 
944 if "input_size" not in f.attrs: 945 return 946 947 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 948 for key, val in signature.items(): 949 # Check whether the key is missing from the attrs or if the value is not matching. 950 if key not in f.attrs or f.attrs[key] != val: 951 # These keys were recently added, so we don't want to fail yet if they don't 952 # match in order to not invalidate previous embedding files. 953 # Instead we just raise a warning. (For the version we probably also don't want to fail 954 # i the future since it should not invalidate the embeddings). 955 if key in ("micro_sam_version", "model_hash", "model_name"): 956 warnings.warn( 957 f"The signature for {key} in embeddings file {save_path} has a mismatch: " 958 f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " 959 "But please recompute them if model predictions don't look as expected." 960 ) 961 else: 962 raise RuntimeError( 963 f"Embeddings file {save_path} is invalid due to mismatch in {key}: " 964 f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." 965 ) 966 967 968# Helper function for optional external progress bars. 969def handle_pbar(verbose, pbar_init, pbar_update): 970 """@private""" 971 972 # Noop to provide dummy functions. 973 def noop(*args): 974 pass 975 976 if verbose and pbar_init is None: # we are verbose and don't have an external progress bar. 977 assert pbar_update is None # avoid inconsistent state of callbacks 978 979 # Create our own progress bar and callbacks 980 pbar = tqdm() 981 982 def pbar_init(total, description): 983 pbar.total = total 984 pbar.set_description(description) 985 986 def pbar_update(update): 987 pbar.update(update) 988 989 def pbar_close(): 990 pbar.close() 991 992 elif verbose and pbar_init is not None: # external pbar -> we don't have to do anything 993 assert pbar_update is not None 994 pbar = None 995 pbar_close = noop 996 997 else: # we are not verbose, do nothing 998 pbar = None 999 pbar_init, pbar_update, pbar_close = noop, noop, noop 1000 1001 return pbar, pbar_init, pbar_update, pbar_close 1002 1003 1004def precompute_image_embeddings( 1005 predictor: SamPredictor, 1006 input_: np.ndarray, 1007 save_path: Optional[Union[str, os.PathLike]] = None, 1008 lazy_loading: bool = False, 1009 ndim: Optional[int] = None, 1010 tile_shape: Optional[Tuple[int, int]] = None, 1011 halo: Optional[Tuple[int, int]] = None, 1012 verbose: bool = True, 1013 batch_size: int = 1, 1014 pbar_init: Optional[callable] = None, 1015 pbar_update: Optional[callable] = None, 1016) -> ImageEmbeddings: 1017 """Compute the image embeddings (output of the encoder) for the input. 1018 1019 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 1020 1021 Args: 1022 predictor: The Segment Anything predictor. 1023 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 1024 save_path: Path to save the embeddings in a zarr container. 1025 By default, set to 'None', i.e. the computed embeddings will not be stored locally. 1026 lazy_loading: Whether to load all embeddings into memory or return an 1027 object to load them on demand when required. This only has an effect if 'save_path' is given 1028 and if the input is 3 dimensional. By default, set to 'False'. 1029 ndim: The dimensionality of the data. If not given will be deduced from the input data. 1030 By default, set to 'None', i.e. will be computed from the provided `input_`. 
1031 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 1032 halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling. 1033 verbose: Whether to be verbose in the computation. By default, set to 'True'. 1034 batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'. 1035 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1036 Can be used together with pbar_update to handle napari progress bar in other thread. 1037 To enables using this function within a threadworker. 1038 pbar_update: Callback to update an external progress bar. 1039 1040 Returns: 1041 The image embeddings. 1042 """ 1043 ndim = input_.ndim if ndim is None else ndim 1044 1045 # Handle the embedding save_path. 1046 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 1047 if save_path is None: 1048 f = zarr.group() 1049 1050 # We have a save path and it already exists. Embeddings will be loaded from it, 1051 # check that the saved embeddings in there match the parameters of the function call. 1052 elif os.path.exists(save_path): 1053 f = zarr.open(save_path, mode="a") 1054 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 1055 1056 # We have a save path and it does not exist yet. Create the zarr file to which the 1057 # embeddings will then be saved. 1058 else: 1059 f = zarr.open(save_path, mode="a") 1060 1061 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 1062 1063 if ndim == 2 and tile_shape is None: 1064 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 1065 elif ndim == 2 and tile_shape is not None: 1066 embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 1067 elif ndim == 3 and tile_shape is None: 1068 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size) 1069 elif ndim == 3 and tile_shape is not None: 1070 embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 1071 else: 1072 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 1073 1074 pbar_close() 1075 return embeddings 1076 1077 1078def set_precomputed( 1079 predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None, 1080) -> SamPredictor: 1081 """Set the precomputed image embeddings for a predictor. 1082 1083 Args: 1084 predictor: The Segment Anything predictor. 1085 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 1086 i: Index for the image data. Required if `image` has three spatial dimensions 1087 or a time dimension and two spatial dimensions. 1088 tile_id: Index for the tile. This is required if the embeddings are tiled. 1089 1090 Returns: 1091 The predictor with set features. 
1092 """ 1093 if tile_id is not None: 1094 tile_features = image_embeddings["features"][str(tile_id)] 1095 tile_image_embeddings = { 1096 "features": tile_features, 1097 "input_size": tile_features.attrs["input_size"], 1098 "original_size": tile_features.attrs["original_size"] 1099 } 1100 return set_precomputed(predictor, tile_image_embeddings, i=i) 1101 1102 device = predictor.device 1103 features = image_embeddings["features"] 1104 assert features.ndim in (4, 5), f"{features.ndim}" 1105 if features.ndim == 5 and i is None: 1106 raise ValueError("The data is 3D so an index i is needed.") 1107 elif features.ndim == 4 and i is not None: 1108 raise ValueError("The data is 2D so an index is not needed.") 1109 1110 if i is None: 1111 predictor.features = features.to(device) if torch.is_tensor(features) else \ 1112 torch.from_numpy(features[:]).to(device) 1113 else: 1114 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 1115 torch.from_numpy(features[i]).to(device) 1116 1117 predictor.original_size = image_embeddings["original_size"] 1118 predictor.input_size = image_embeddings["input_size"] 1119 predictor.is_image_set = True 1120 1121 return predictor 1122 1123 1124# 1125# Misc functionality 1126# 1127 1128 1129def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 1130 """Compute the intersection over union of two masks. 1131 1132 Args: 1133 mask1: The first mask. 1134 mask2: The second mask. 1135 1136 Returns: 1137 The intersection over union of the two masks. 1138 """ 1139 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 1140 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 1141 eps = 1e-7 1142 iou = float(overlap) / (float(union) + eps) 1143 return iou 1144 1145 1146def get_centers_and_bounding_boxes( 1147 segmentation: np.ndarray, mode: str = "v" 1148) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 1149 """Returns the center coordinates of the foreground instances in the ground-truth. 1150 1151 Args: 1152 segmentation: The segmentation. 1153 mode: Determines the functionality used for computing the centers. 1154 If 'v', the object's eccentricity centers computed by vigra are used. 1155 If 'p' the object's centroids computed by skimage are used. 1156 1157 Returns: 1158 A dictionary that maps object ids to the corresponding centroid. 1159 A dictionary that maps object_ids to the corresponding bounding box. 1160 """ 1161 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 1162 1163 properties = regionprops(segmentation) 1164 1165 if mode == "p": 1166 center_coordinates = {prop.label: prop.centroid for prop in properties} 1167 elif mode == "v": 1168 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 1169 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 1170 1171 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 1172 1173 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 1174 return center_coordinates, bbox_coordinates 1175 1176 1177def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray: 1178 """Helper function to load image data from file. 1179 1180 Args: 1181 path: The filepath to the image data. 1182 key: The internal filepath for complex data formats like hdf5. 1183 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 1184 1185 Returns: 1186 The image data. 
1187 """ 1188 if key is None: 1189 image_data = imageio.imread(path) 1190 else: 1191 with open_file(path, mode="r") as f: 1192 image_data = f[key] 1193 if not lazy_loading: 1194 image_data = image_data[:] 1195 1196 return image_data 1197 1198 1199def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor: 1200 """Convert the segmentation to one-hot encoded masks. 1201 1202 Args: 1203 segmentation: The segmentation. 1204 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1205 By default, computes the number of ids from the provided `segmentation` masks. 1206 1207 Returns: 1208 The one-hot encoded masks. 1209 """ 1210 masks = segmentation.copy() 1211 if segmentation_ids is None: 1212 n_ids = int(segmentation.max()) 1213 1214 else: 1215 msg = "No foreground objects were found." 1216 if len(segmentation_ids) == 0: # The list should not be completely empty. 1217 raise RuntimeError(msg) 1218 1219 if 0 in segmentation_ids: # The list should not have 'zero' as a value. 1220 raise RuntimeError(msg) 1221 1222 # the segmentation ids have to be sorted 1223 segmentation_ids = np.sort(segmentation_ids) 1224 1225 # set the non selected objects to zero and relabel sequentially 1226 masks[~np.isin(masks, segmentation_ids)] = 0 1227 masks = relabel_sequential(masks)[0] 1228 n_ids = len(segmentation_ids) 1229 1230 masks = torch.from_numpy(masks) 1231 1232 one_hot_shape = (n_ids + 1,) + masks.shape 1233 masks = masks.unsqueeze(0) # add dimension to scatter 1234 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1235 1236 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1237 masks = masks.unsqueeze(1) 1238 return masks 1239 1240 1241def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1242 """Get a suitable block shape for chunking a given shape. 1243 1244 The primary use for this is determining chunk sizes for 1245 zarr arrays or block shapes for parallelization. 1246 1247 Args: 1248 shape: The image or volume shape. 1249 1250 Returns: 1251 The block shape. 1252 """ 1253 ndim = len(shape) 1254 if ndim == 2: 1255 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1256 elif ndim == 3: 1257 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1258 else: 1259 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1260 1261 return block_shape 1262 1263 1264def micro_sam_info(): 1265 """Display μSAM information using a rich console.""" 1266 import psutil 1267 import platform 1268 import argparse 1269 from rich import progress 1270 from rich.panel import Panel 1271 from rich.table import Table 1272 from rich.console import Console 1273 1274 import torch 1275 import micro_sam 1276 1277 parser = argparse.ArgumentParser(description="μSAM Information Booth") 1278 parser.add_argument( 1279 "--download", nargs="+", metavar=("WHAT", "KIND"), 1280 help="Downloads the pretrained SAM models." 1281 "'--download models' -> downloads all pretrained models; " 1282 "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; " 1283 "'--download model/models vit_b_lm' -> downloads a single specified model." 1284 ) 1285 args = parser.parse_args() 1286 1287 # Open up a new console. 1288 console = Console() 1289 1290 # The header for information CLI. 1291 console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center") 1292 console.print("-" * console.width) 1293 1294 # μSAM version panel. 
1295 console.print( 1296 Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True) 1297 ) 1298 1299 # The documentation link panel. 1300 console.print( 1301 Panel( 1302 "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n" 1303 "https://computational-cell-analytics.github.io/micro-sam", title="Documentation" 1304 ) 1305 ) 1306 1307 # The publication panel. 1308 console.print( 1309 Panel( 1310 "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n" 1311 "https://www.nature.com/articles/s41592-024-02580-4", title="Publication" 1312 ) 1313 ) 1314 1315 # The cache directory panel. 1316 console.print( 1317 Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{get_cache_directory()}", title="Cache Directory") 1318 ) 1319 1320 # The available models panel. 1321 available_models = list(get_model_names()) 1322 # We filter out the decoder models. 1323 available_models = [m for m in available_models if not m.endswith("_decoder")] 1324 model_list = "\n".join(available_models) 1325 console.print( 1326 Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models") 1327 ) 1328 1329 # The system information table. 1330 total_memory = psutil.virtual_memory().total / (1024 ** 3) 1331 table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True) 1332 table.add_column("Property") 1333 table.add_column("Value", style="bold #56B4E9") 1334 table.add_row("System", platform.system()) 1335 table.add_row("Node Name", platform.node()) 1336 table.add_row("Release", platform.release()) 1337 table.add_row("Version", platform.version()) 1338 table.add_row("Machine", platform.machine()) 1339 table.add_row("Processor", platform.processor()) 1340 table.add_row("Platform", platform.platform()) 1341 table.add_row("Total RAM (GB)", f"{total_memory:.2f}") 1342 console.print(table) 1343 1344 # The device information and check for available GPU acceleration. 1345 default_device = _get_default_device() 1346 1347 if default_device == "cuda": 1348 device_index = torch.cuda.current_device() 1349 device_name = torch.cuda.get_device_name(device_index) 1350 console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information")) 1351 elif default_device == "mps": 1352 console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information")) 1353 else: 1354 console.print( 1355 Panel( 1356 "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]", 1357 title="Device Information" 1358 ) 1359 ) 1360 1361 # The section allowing to download models. 1362 # NOTE: In future, can be extended to download sample data. 1363 if args.download: 1364 download_provided_args = [t.lower() for t in args.download] 1365 mode, *model_types = download_provided_args 1366 1367 if mode not in {"models", "model"}: 1368 console.print(f"[red]Unknown option for --download: {mode}[/]") 1369 return 1370 1371 if mode in ["model", "models"] and not model_types: # If user did not specify, we will download all models. 
1372 download_list = available_models 1373 else: 1374 download_list = model_types 1375 incorrect_models = [m for m in download_list if m not in available_models] 1376 if incorrect_models: 1377 console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error")) 1378 return 1379 1380 with progress.Progress( 1381 progress.SpinnerColumn(), 1382 progress.TextColumn("[progress.description]{task.description}"), 1383 progress.BarColumn(bar_width=None), 1384 "[progress.percentage]{task.percentage:>3.0f}%", 1385 progress.TimeRemainingColumn(), 1386 console=console, 1387 ) as prog: 1388 task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list)) 1389 for model_type in download_list: 1390 prog.update(task, description=f"Downloading [cyan]{model_type}[/]…") 1391 _download_sam_model(model_type=model_type) 1392 prog.advance(task) 1393 1394 console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))
def get_cache_directory() -> Path:
    """Get micro-sam cache directory location.

    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
    """
    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
    return cache_directory
Get micro-sam cache directory location.
Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
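A minimal sketch of customizing the cache location; the directory name below is only an example.

import os

# Hypothetical custom cache location; any writable directory works.
os.environ["MICROSAM_CACHEDIR"] = os.path.expanduser("~/my_micro_sam_cache")

from micro_sam.util import get_cache_directory

# The environment variable is read at call time, so this now points to the custom location.
print(get_cache_directory())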
def microsam_cachedir() -> Union[str, Path]:
    """Return the micro-sam cache directory.

    Returns the top level cache directory for micro-sam models and sample data.

    Every time this function is called, we check for any user updates made to
    the MICROSAM_CACHEDIR os environment variable since the last time.
    """
    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
    return cache_directory
Return the micro-sam cache directory.
Returns the top level cache directory for micro-sam models and sample data.
Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.
87def models(): 88 """Return the segmentation models registry. 89 90 We recreate the model registry every time this function is called, 91 so any user changes to the default micro-sam cache directory location 92 are respected. 93 """ 94 95 # We use xxhash to compute the hash of the models, see 96 # https://github.com/computational-cell-analytics/micro-sam/issues/283 97 # (It is now a dependency, so we don't provide the sha256 fallback anymore.) 98 # To generate the xxh128 hash: 99 # xxh128sum filename 100 encoder_registry = { 101 # The default segment anything models: 102 "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", 103 "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", 104 "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", 105 # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM. 106 "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", 107 # The current version of our models in the modelzoo. 108 # LM generalist models: 109 "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d", 110 "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186", 111 "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e", 112 # EM models: 113 "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d", 114 "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438", 115 "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9", 116 # Histopathology models: 117 "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974", 118 "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2", 119 "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e", 120 # Medical Imaging models: 121 "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1", 122 } 123 # Additional decoders for instance segmentation. 124 decoder_registry = { 125 # LM generalist models: 126 "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a", 127 "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba", 128 "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823", 129 # EM models: 130 "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294", 131 "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2", 132 "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44", 133 # Histopathology models: 134 "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213", 135 "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04", 136 "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47", 137 } 138 registry = {**encoder_registry, **decoder_registry} 139 140 encoder_urls = { 141 "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 142 "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 143 "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 144 "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", 145 "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt", 146 "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt", 147 "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt", 148 "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt", # noqa 149 "vit_b_em_organelles": 
"https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt", # noqa 150 "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa 151 "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download", 152 "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download", 153 "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download", 154 "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download", 155 } 156 157 decoder_urls = { 158 "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt", # noqa 159 "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt", # noqa 160 "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt", # noqa 161 "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt", # noqa 162 "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt", # noqa 163 "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa 164 "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download", 165 "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download", 166 "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download", 167 } 168 urls = {**encoder_urls, **decoder_urls} 169 170 models = pooch.create( 171 path=os.path.join(microsam_cachedir(), "models"), 172 base_url="", 173 registry=registry, 174 urls=urls, 175 ) 176 return models
Return the segmentation models registry.
We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
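As a small usage sketch (assuming micro_sam is installed), the returned pooch registry can be used to resolve and download individual checkpoints:

from micro_sam.util import models

registry = models()

# fetch() downloads the file into the micro-sam cache if needed (verifying the
# registered xxh128 hash) and returns the local path to the checkpoint.
encoder_path = registry.fetch("vit_b_lm")
decoder_path = registry.fetch("vit_b_lm_decoder")  # matching instance segmentation decoder
print(encoder_path, decoder_path)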
198def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: 199 """Get the torch device. 200 201 If no device is passed the default device for your system is used. 202 Else it will be checked if the device you have passed is supported. 203 204 Args: 205 device: The input device. By default, selects the best available device supports. 206 207 Returns: 208 The device. 209 """ 210 if device is None or device == "auto": 211 device = _get_default_device() 212 else: 213 device_type = device if isinstance(device, str) else device.type 214 if device_type.lower() == "cuda": 215 if not torch.cuda.is_available(): 216 raise RuntimeError("PyTorch CUDA backend is not available.") 217 elif device_type.lower() == "mps": 218 if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): 219 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") 220 elif device_type.lower() == "cpu": 221 pass # cpu is always available 222 else: 223 raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.") 224 225 return device
Get the torch device.
If no device is passed, the default device for your system is used. Otherwise, it is checked whether the device you have passed is supported.
Arguments:
- device: The input device. By default, the best available device is selected.
Returns:
The device.
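A short sketch of the intended usage; the device names are the standard PyTorch backends:

import torch
from micro_sam.util import get_device

device = get_device()           # picks "cuda", "mps" or "cpu" automatically
cpu_device = get_device("cpu")  # explicit devices are validated; "cpu" always works

x = torch.zeros(4, device=device)
print(device, x.device)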
312def get_sam_model( 313 model_type: str = _DEFAULT_MODEL, 314 device: Optional[Union[str, torch.device]] = None, 315 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 316 return_sam: bool = False, 317 return_state: bool = False, 318 peft_kwargs: Optional[Dict] = None, 319 flexible_load_checkpoint: bool = False, 320 progress_bar_factory: Optional[Callable] = None, 321 **model_kwargs, 322) -> SamPredictor: 323 r"""Get the Segment Anything Predictor. 324 325 This function will download the required model or load it from the cached weight file. 326 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 327 The name of the requested model can be set via `model_type`. 328 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 329 for an overview of the available models 330 331 Alternatively this function can also load a model from weights stored in a local filepath. 332 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 333 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 334 a SAM model with vit_b encoder. 335 336 By default the models are downloaded to a folder named 'micro_sam/models' 337 inside your default cache directory, eg: 338 * Mac: ~/Library/Caches/<AppName> 339 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 340 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 341 See the pooch.os_cache() documentation for more details: 342 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 343 344 Args: 345 model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. 346 To get a list of all available model names you can call `micro_sam.util.get_model_names`. 347 device: The device for the model. If 'None' is provided, will use GPU if available. 348 checkpoint_path: The path to a file with weights that should be used instead of using the 349 weights corresponding to `model_type`. If given, `model_type` must match the architecture 350 corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder 351 then `model_type` must be given as 'vit_b'. 352 return_sam: Return the sam model object as well as the predictor. By default, set to 'False'. 353 return_state: Return the unpickled checkpoint state. By default, set to 'False'. 354 peft_kwargs: Keyword arguments for th PEFT wrapper class. 355 If passed 'None', it does not initialize any parameter efficient finetuning. 356 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 357 By default, set to 'False'. 358 progress_bar_factory: A function to create a progress bar for the model download. 359 model_kwargs: Additional parameters necessary to initialize the Segment Anything model. 360 361 Returns: 362 The Segment Anything predictor. 363 """ 364 device = get_device(device) 365 366 # We support passing a local filepath to a checkpoint. 367 # In this case we do not download any weights but just use the local weight file, 368 # as it is, without copying it over anywhere or checking it's hashes. 369 370 # checkpoint_path has not been passed, we download a known model and derive the correct 371 # URL from the model_type. If the model_type is invalid pooch will raise an error. 
372 _provided_checkpoint_path = checkpoint_path is not None 373 if checkpoint_path is None: 374 checkpoint_path, model_hash, decoder_path = _download_sam_model(model_type, progress_bar_factory) 375 376 # checkpoint_path has been passed, we use it instead of downloading a model. 377 else: 378 # Check if the file exists and raise an error otherwise. 379 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 380 # (If it isn't the model creation will fail below.) 381 if not os.path.exists(checkpoint_path): 382 raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.") 383 model_hash = _compute_hash(checkpoint_path) 384 decoder_path = None 385 386 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 387 # before calling sam_model_registry. 388 abbreviated_model_type = model_type[:5] 389 if abbreviated_model_type not in _MODEL_TYPES: 390 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 391 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 392 raise RuntimeError( 393 "'mobile_sam' is required for the vit-tiny. " 394 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 395 ) 396 397 state, model_state = _load_checkpoint(checkpoint_path) 398 399 if _provided_checkpoint_path: 400 # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type' 401 # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination. 402 from micro_sam.models.build_sam import _validate_model_type 403 _provided_model_type = _validate_model_type(model_state) 404 405 # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type' 406 # Otherwise replace 'abbreviated_model_type' with the later. 407 if abbreviated_model_type != _provided_model_type: 408 # Printing the message below to avoid any filtering of warnings on user's end. 409 print( 410 f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', " 411 f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. " 412 f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. " 413 "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future." 414 ) 415 416 # Replace the extracted 'abbreviated_model_type' subjected to the model weights. 417 abbreviated_model_type = _provided_model_type 418 419 # Whether to update parameters necessary to initialize the model 420 if model_kwargs: # Checks whether model_kwargs have been provided or not 421 if abbreviated_model_type == "vit_t": 422 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 423 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 424 425 else: 426 sam = sam_model_registry[abbreviated_model_type]() 427 428 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 429 # Overwrites the SAM model by freezing the backbone and allow PEFT. 430 if peft_kwargs and isinstance(peft_kwargs, dict): 431 # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference. 
432 peft_kwargs.pop("quantize", None) 433 434 if abbreviated_model_type == "vit_t": 435 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 436 437 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 438 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 439 if flexible_load_checkpoint: 440 sam = _handle_checkpoint_loading(sam, model_state) 441 else: 442 sam.load_state_dict(model_state) 443 sam.to(device=device) 444 445 predictor = SamPredictor(sam) 446 predictor.model_type = abbreviated_model_type 447 predictor._hash = model_hash 448 predictor.model_name = model_type 449 predictor.checkpoint_path = checkpoint_path 450 451 # Add the decoder to the state if we have one and if the state is returned. 452 if decoder_path is not None and return_state: 453 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 454 455 if return_sam and return_state: 456 return predictor, sam, state 457 if return_sam: 458 return predictor, sam 459 if return_state: 460 return predictor, state 461 return predictor
Get the Segment Anything Predictor.
This function will download the required model or load it from the cached weight file.
The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
The name of the requested model can be set via 'model_type'. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models.
Alternatively this function can also load a model from weights stored in a local filepath. The corresponding file path is given via 'checkpoint_path'. In this case 'model_type' must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder.
By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g.:
- Mac: ~/Library/Caches/<AppName>
- Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
- Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
Arguments:
- model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. To get a list of all available model names you can call micro_sam.util.get_model_names.
- device: The device for the model. If 'None' is provided, will use GPU if available.
- checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to 'model_type'. If given, 'model_type' must match the architecture corresponding to the weight file, e.g. if you use weights for SAM with a 'vit_b' encoder then 'model_type' must be given as 'vit_b'.
- return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
- return_state: Return the unpickled checkpoint state. By default, set to 'False'.
- peft_kwargs: Keyword arguments for the PEFT wrapper class. If passed 'None', it does not initialize any parameter efficient finetuning.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. By default, set to 'False'.
- progress_bar_factory: A function to create a progress bar for the model download.
- model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
Returns:
The Segment Anything predictor.
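A minimal sketch of both usage modes; the local checkpoint path is hypothetical:

from micro_sam.util import get_sam_model

# Download (or load from the cache) the default 'vit_b_lm' predictor.
predictor = get_sam_model(model_type="vit_b_lm")

# Load weights from a local file instead; 'model_type' must name the matching encoder.
predictor_custom = get_sam_model(
    model_type="vit_b", checkpoint_path="finetuned_vit_b.pt",  # hypothetical path
)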
497def export_custom_sam_model( 498 checkpoint_path: Union[str, os.PathLike], 499 model_type: str, 500 save_path: Union[str, os.PathLike], 501 with_segmentation_decoder: bool = False, 502 prefix: str = "sam.", 503) -> None: 504 """Export a finetuned Segment Anything Model to the standard model format. 505 506 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 507 508 Args: 509 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 510 model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 511 save_path: Where to save the exported model. 512 with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. 513 If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'. 514 prefix: The prefix to remove from the model parameter keys. 515 """ 516 state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path) 517 model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]) 518 519 # Store the 'decoder_state' as well, if desired. 520 if with_segmentation_decoder: 521 if "decoder_state" not in state: 522 raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.") 523 decoder_state = state["decoder_state"] 524 save_state = {"model_state": model_state, "decoder_state": decoder_state} 525 else: 526 save_state = model_state 527 528 torch.save(save_state, save_path)
Export a finetuned Segment Anything Model to the standard model format.
The exported model can be used by the interactive annotation tools in micro_sam.annotator.
Arguments:
- checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
- model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
- save_path: Where to save the exported model.
- with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
- prefix: The prefix to remove from the model parameter keys.
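A sketch with hypothetical paths, exporting a finetuned checkpoint together with its instance segmentation decoder:

from micro_sam.util import export_custom_sam_model

export_custom_sam_model(
    checkpoint_path="checkpoints/vit_b_finetuned/best.pt",  # hypothetical finetuning result
    model_type="vit_b",
    save_path="vit_b_custom.pt",
    with_segmentation_decoder=True,  # requires a 'decoder_state' in the checkpoint
)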
531def export_custom_qlora_model( 532 checkpoint_path: Optional[Union[str, os.PathLike]], 533 finetuned_path: Union[str, os.PathLike], 534 model_type: str, 535 save_path: Union[str, os.PathLike], 536) -> None: 537 """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format. 538 539 The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`. 540 541 Args: 542 checkpoint_path: The path to the base foundation model from which the new model has been finetuned. 543 finetuned_path: The path to the new finetuned model, using QLoRA. 544 model_type: The Segment Anything Model type corresponding to the checkpoint. 545 save_path: Where to save the exported model. 546 """ 547 # Step 1: Get the base SAM model: used to start finetuning from. 548 _, sam = get_sam_model( 549 model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, 550 ) 551 552 # Step 2: Load the QLoRA-style finetuned model. 553 ft_state, ft_model_state = _load_checkpoint(finetuned_path) 554 555 # Step 3: Identify LoRA layers from QLoRA model. 556 # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers. 557 # - then copy the LoRA layers from the QLoRA model to the new state dict 558 updated_model_state = {} 559 560 modified_attn_layers = set() 561 modified_mlp_layers = set() 562 563 for k, v in ft_model_state.items(): 564 if "blocks." in k: 565 layer_id = int(k.split("blocks.")[1].split(".")[0]) 566 if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1: 567 modified_attn_layers.add(layer_id) 568 updated_model_state[k] = v 569 if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1: 570 modified_mlp_layers.add(layer_id) 571 updated_model_state[k] = v 572 573 # Step 4: Next, we get all the remaining parameters from the base SAM model. 574 for k, v in sam.state_dict().items(): 575 if "blocks." in k: 576 layer_id = int(k.split("blocks.")[1].split(".")[0]) 577 if k.find("attn.qkv.") != -1: 578 if layer_id in modified_attn_layers: # We have LoRA in QKV layers, so we need to modify the key 579 k = k.replace("qkv", "qkv.qkv_proj") 580 elif k.find("mlp") != -1 and k.find("image_encoder") != -1: 581 if layer_id in modified_mlp_layers: # We have LoRA in MLP layers, so we need to modify the key 582 k = k.replace("mlp.", "mlp.mlp_layer.") 583 updated_model_state[k] = v 584 585 # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff) 586 ft_state['model_state'] = updated_model_state 587 588 # Step 6: Store the new "state" to "save_path" 589 torch.save(ft_state, save_path)
Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
The exported model can be used with the LoRA backbone by passing the relevant 'peft_kwargs' to get_sam_model.
Arguments:
- checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
- finetuned_path: The path to the new finetuned model, using QLoRA.
- model_type: The Segment Anything Model type corresponding to the checkpoint.
- save_path: Where to save the exported model.
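A sketch with hypothetical paths; the 'rank' value passed via peft_kwargs below is an assumption and must match the rank used during finetuning:

from micro_sam.util import export_custom_qlora_model, get_sam_model

export_custom_qlora_model(
    checkpoint_path=None,             # use the default base weights for this model type
    finetuned_path="qlora_vit_b.pt",  # hypothetical QLoRA finetuning result
    model_type="vit_b",
    save_path="lora_vit_b.pt",
)

# Reload the exported checkpoint with the LoRA backbone.
predictor = get_sam_model(
    model_type="vit_b", checkpoint_path="lora_vit_b.pt", peft_kwargs={"rank": 4},
)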
1005def precompute_image_embeddings( 1006 predictor: SamPredictor, 1007 input_: np.ndarray, 1008 save_path: Optional[Union[str, os.PathLike]] = None, 1009 lazy_loading: bool = False, 1010 ndim: Optional[int] = None, 1011 tile_shape: Optional[Tuple[int, int]] = None, 1012 halo: Optional[Tuple[int, int]] = None, 1013 verbose: bool = True, 1014 batch_size: int = 1, 1015 pbar_init: Optional[callable] = None, 1016 pbar_update: Optional[callable] = None, 1017) -> ImageEmbeddings: 1018 """Compute the image embeddings (output of the encoder) for the input. 1019 1020 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 1021 1022 Args: 1023 predictor: The Segment Anything predictor. 1024 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 1025 save_path: Path to save the embeddings in a zarr container. 1026 By default, set to 'None', i.e. the computed embeddings will not be stored locally. 1027 lazy_loading: Whether to load all embeddings into memory or return an 1028 object to load them on demand when required. This only has an effect if 'save_path' is given 1029 and if the input is 3 dimensional. By default, set to 'False'. 1030 ndim: The dimensionality of the data. If not given will be deduced from the input data. 1031 By default, set to 'None', i.e. will be computed from the provided `input_`. 1032 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 1033 halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling. 1034 verbose: Whether to be verbose in the computation. By default, set to 'True'. 1035 batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'. 1036 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1037 Can be used together with pbar_update to handle napari progress bar in other thread. 1038 To enables using this function within a threadworker. 1039 pbar_update: Callback to update an external progress bar. 1040 1041 Returns: 1042 The image embeddings. 1043 """ 1044 ndim = input_.ndim if ndim is None else ndim 1045 1046 # Handle the embedding save_path. 1047 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 1048 if save_path is None: 1049 f = zarr.group() 1050 1051 # We have a save path and it already exists. Embeddings will be loaded from it, 1052 # check that the saved embeddings in there match the parameters of the function call. 1053 elif os.path.exists(save_path): 1054 f = zarr.open(save_path, mode="a") 1055 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 1056 1057 # We have a save path and it does not exist yet. Create the zarr file to which the 1058 # embeddings will then be saved. 
1059 else: 1060 f = zarr.open(save_path, mode="a") 1061 1062 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 1063 1064 if ndim == 2 and tile_shape is None: 1065 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 1066 elif ndim == 2 and tile_shape is not None: 1067 embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 1068 elif ndim == 3 and tile_shape is None: 1069 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size) 1070 elif ndim == 3 and tile_shape is not None: 1071 embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 1072 else: 1073 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 1074 1075 pbar_close() 1076 return embeddings
Compute the image embeddings (output of the encoder) for the input.
If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
Arguments:
- predictor: The Segment Anything predictor.
- input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
- save_path: Path to save the embeddings in a zarr container. By default, set to 'None', i.e. the computed embeddings will not be stored locally.
- lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional. By default, set to 'False'.
- ndim: The dimensionality of the data. If not given, it will be deduced from the input data. By default, set to 'None', i.e. it will be computed from the provided 'input_'.
- tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Whether to be verbose in the computation. By default, set to 'True'.
- batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
Returns:
The image embeddings.
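A minimal 2d sketch, using a random placeholder image and a hypothetical save path:

import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings

predictor = get_sam_model(model_type="vit_b_lm")
image = np.random.randint(0, 255, size=(1024, 1024), dtype="uint8")  # placeholder data

# Embeddings are stored in the zarr container and re-used on the next call with the same save_path.
embeddings = precompute_image_embeddings(
    predictor, image, save_path="embeddings.zarr", ndim=2,
    # For large images tiling can be enabled, e.g. tile_shape=(1024, 1024), halo=(256, 256).
)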
1079def set_precomputed( 1080 predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None, 1081) -> SamPredictor: 1082 """Set the precomputed image embeddings for a predictor. 1083 1084 Args: 1085 predictor: The Segment Anything predictor. 1086 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 1087 i: Index for the image data. Required if `image` has three spatial dimensions 1088 or a time dimension and two spatial dimensions. 1089 tile_id: Index for the tile. This is required if the embeddings are tiled. 1090 1091 Returns: 1092 The predictor with set features. 1093 """ 1094 if tile_id is not None: 1095 tile_features = image_embeddings["features"][str(tile_id)] 1096 tile_image_embeddings = { 1097 "features": tile_features, 1098 "input_size": tile_features.attrs["input_size"], 1099 "original_size": tile_features.attrs["original_size"] 1100 } 1101 return set_precomputed(predictor, tile_image_embeddings, i=i) 1102 1103 device = predictor.device 1104 features = image_embeddings["features"] 1105 assert features.ndim in (4, 5), f"{features.ndim}" 1106 if features.ndim == 5 and i is None: 1107 raise ValueError("The data is 3D so an index i is needed.") 1108 elif features.ndim == 4 and i is not None: 1109 raise ValueError("The data is 2D so an index is not needed.") 1110 1111 if i is None: 1112 predictor.features = features.to(device) if torch.is_tensor(features) else \ 1113 torch.from_numpy(features[:]).to(device) 1114 else: 1115 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 1116 torch.from_numpy(features[i]).to(device) 1117 1118 predictor.original_size = image_embeddings["original_size"] 1119 predictor.input_size = image_embeddings["input_size"] 1120 predictor.is_image_set = True 1121 1122 return predictor
Set the precomputed image embeddings for a predictor.
Arguments:
- predictor: The Segment Anything predictor.
- image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
- i: Index for the image data. Required if 'image' has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:
The predictor with set features.
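A sketch of prompting with precomputed embeddings instead of calling predictor.set_image; the image is a random placeholder:

import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

predictor = get_sam_model(model_type="vit_b_lm")
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder data

embeddings = precompute_image_embeddings(predictor, image, ndim=2)
predictor = set_precomputed(predictor, embeddings)

# The predictor is now ready for prompting, e.g. with a single positive point.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[256, 256]]), point_labels=np.array([1]), multimask_output=False,
)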
1130def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 1131 """Compute the intersection over union of two masks. 1132 1133 Args: 1134 mask1: The first mask. 1135 mask2: The second mask. 1136 1137 Returns: 1138 The intersection over union of the two masks. 1139 """ 1140 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 1141 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 1142 eps = 1e-7 1143 iou = float(overlap) / (float(union) + eps) 1144 return iou
Compute the intersection over union of two masks.
Arguments:
- mask1: The first mask.
- mask2: The second mask.
Returns:
The intersection over union of the two masks.
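A quick numeric example on two small binary masks:

import numpy as np
from micro_sam.util import compute_iou

mask1 = np.zeros((4, 4), dtype="uint8")
mask2 = np.zeros((4, 4), dtype="uint8")
mask1[:2, :2] = 1    # 4 foreground pixels
mask2[1:3, 1:3] = 1  # 4 foreground pixels, overlapping mask1 in a single pixel

print(compute_iou(mask1, mask2))  # 1 / 7, approximately 0.143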
1147def get_centers_and_bounding_boxes( 1148 segmentation: np.ndarray, mode: str = "v" 1149) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 1150 """Returns the center coordinates of the foreground instances in the ground-truth. 1151 1152 Args: 1153 segmentation: The segmentation. 1154 mode: Determines the functionality used for computing the centers. 1155 If 'v', the object's eccentricity centers computed by vigra are used. 1156 If 'p' the object's centroids computed by skimage are used. 1157 1158 Returns: 1159 A dictionary that maps object ids to the corresponding centroid. 1160 A dictionary that maps object_ids to the corresponding bounding box. 1161 """ 1162 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 1163 1164 properties = regionprops(segmentation) 1165 1166 if mode == "p": 1167 center_coordinates = {prop.label: prop.centroid for prop in properties} 1168 elif mode == "v": 1169 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 1170 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 1171 1172 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 1173 1174 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 1175 return center_coordinates, bbox_coordinates
Returns the center coordinates of the foreground instances in the ground-truth.
Arguments:
- segmentation: The segmentation.
- mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p', the object's centroids computed by skimage are used.
Returns:
A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.
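A small sketch using the skimage-based mode ('p'); mode 'v' works the same way but requires vigra:

import numpy as np
from micro_sam.util import get_centers_and_bounding_boxes

segmentation = np.zeros((8, 8), dtype="uint32")
segmentation[1:3, 1:3] = 1
segmentation[5:8, 4:7] = 2

centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="p")
# centers maps object id -> centroid, boxes maps object id -> (min_row, min_col, max_row, max_col).
print(centers[1], boxes[1])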
1178def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray: 1179 """Helper function to load image data from file. 1180 1181 Args: 1182 path: The filepath to the image data. 1183 key: The internal filepath for complex data formats like hdf5. 1184 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 1185 1186 Returns: 1187 The image data. 1188 """ 1189 if key is None: 1190 image_data = imageio.imread(path) 1191 else: 1192 with open_file(path, mode="r") as f: 1193 image_data = f[key] 1194 if not lazy_loading: 1195 image_data = image_data[:] 1196 1197 return image_data
Helper function to load image data from file.
Arguments:
- path: The filepath to the image data.
- key: The internal filepath for complex data formats like hdf5.
- lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
Returns:
The image data.
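A sketch with hypothetical file paths, covering both a plain image file and a container format with an internal key:

from micro_sam.util import load_image_data

# Plain image formats are read via imageio.
image = load_image_data("example_image.tif")

# Container formats (hdf5 / zarr / n5) need the internal dataset key;
# with lazy_loading=True the zarr/n5 dataset is returned without loading it into memory.
volume = load_image_data("example_volume.zarr", key="raw", lazy_loading=True)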
1200def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor: 1201 """Convert the segmentation to one-hot encoded masks. 1202 1203 Args: 1204 segmentation: The segmentation. 1205 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1206 By default, computes the number of ids from the provided `segmentation` masks. 1207 1208 Returns: 1209 The one-hot encoded masks. 1210 """ 1211 masks = segmentation.copy() 1212 if segmentation_ids is None: 1213 n_ids = int(segmentation.max()) 1214 1215 else: 1216 msg = "No foreground objects were found." 1217 if len(segmentation_ids) == 0: # The list should not be completely empty. 1218 raise RuntimeError(msg) 1219 1220 if 0 in segmentation_ids: # The list should not have 'zero' as a value. 1221 raise RuntimeError(msg) 1222 1223 # the segmentation ids have to be sorted 1224 segmentation_ids = np.sort(segmentation_ids) 1225 1226 # set the non selected objects to zero and relabel sequentially 1227 masks[~np.isin(masks, segmentation_ids)] = 0 1228 masks = relabel_sequential(masks)[0] 1229 n_ids = len(segmentation_ids) 1230 1231 masks = torch.from_numpy(masks) 1232 1233 one_hot_shape = (n_ids + 1,) + masks.shape 1234 masks = masks.unsqueeze(0) # add dimension to scatter 1235 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1236 1237 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1238 masks = masks.unsqueeze(1) 1239 return masks
Convert the segmentation to one-hot encoded masks.
Arguments:
- segmentation: The segmentation.
- segmentation_ids: Optional subset of ids that will be used to subsample the masks. By default, computes the number of ids from the provided 'segmentation' masks.
Returns:
The one-hot encoded masks.
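A toy example for two objects; note that the label image should use an integer dtype such as int64:

import numpy as np
from micro_sam.util import segmentation_to_one_hot

segmentation = np.zeros((6, 6), dtype="int64")
segmentation[0:2, 0:2] = 1
segmentation[3:5, 3:5] = 2

masks = segmentation_to_one_hot(segmentation, segmentation_ids=np.array([1, 2]))
print(masks.shape)  # torch.Size([2, 1, 6, 6]): NUM_OBJECTS x 1 x H x W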
1242def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1243 """Get a suitable block shape for chunking a given shape. 1244 1245 The primary use for this is determining chunk sizes for 1246 zarr arrays or block shapes for parallelization. 1247 1248 Args: 1249 shape: The image or volume shape. 1250 1251 Returns: 1252 The block shape. 1253 """ 1254 ndim = len(shape) 1255 if ndim == 2: 1256 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1257 elif ndim == 3: 1258 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1259 else: 1260 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1261 1262 return block_shape
Get a suitable block shape for chunking a given shape.
The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.
Arguments:
- shape: The image or volume shape.
Returns:
The block shape.
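Two quick examples, for a 2d and a 3d shape:

from micro_sam.util import get_block_shape

print(get_block_shape((768, 2048)))     # (768, 1024)
print(get_block_shape((20, 512, 512)))  # (20, 256, 256)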
1265def micro_sam_info(): 1266 """Display μSAM information using a rich console.""" 1267 import psutil 1268 import platform 1269 import argparse 1270 from rich import progress 1271 from rich.panel import Panel 1272 from rich.table import Table 1273 from rich.console import Console 1274 1275 import torch 1276 import micro_sam 1277 1278 parser = argparse.ArgumentParser(description="μSAM Information Booth") 1279 parser.add_argument( 1280 "--download", nargs="+", metavar=("WHAT", "KIND"), 1281 help="Downloads the pretrained SAM models." 1282 "'--download models' -> downloads all pretrained models; " 1283 "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; " 1284 "'--download model/models vit_b_lm' -> downloads a single specified model." 1285 ) 1286 args = parser.parse_args() 1287 1288 # Open up a new console. 1289 console = Console() 1290 1291 # The header for information CLI. 1292 console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center") 1293 console.print("-" * console.width) 1294 1295 # μSAM version panel. 1296 console.print( 1297 Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True) 1298 ) 1299 1300 # The documentation link panel. 1301 console.print( 1302 Panel( 1303 "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n" 1304 "https://computational-cell-analytics.github.io/micro-sam", title="Documentation" 1305 ) 1306 ) 1307 1308 # The publication panel. 1309 console.print( 1310 Panel( 1311 "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n" 1312 "https://www.nature.com/articles/s41592-024-02580-4", title="Publication" 1313 ) 1314 ) 1315 1316 # The cache directory panel. 1317 console.print( 1318 Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{get_cache_directory()}", title="Cache Directory") 1319 ) 1320 1321 # The available models panel. 1322 available_models = list(get_model_names()) 1323 # We filter out the decoder models. 1324 available_models = [m for m in available_models if not m.endswith("_decoder")] 1325 model_list = "\n".join(available_models) 1326 console.print( 1327 Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models") 1328 ) 1329 1330 # The system information table. 1331 total_memory = psutil.virtual_memory().total / (1024 ** 3) 1332 table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True) 1333 table.add_column("Property") 1334 table.add_column("Value", style="bold #56B4E9") 1335 table.add_row("System", platform.system()) 1336 table.add_row("Node Name", platform.node()) 1337 table.add_row("Release", platform.release()) 1338 table.add_row("Version", platform.version()) 1339 table.add_row("Machine", platform.machine()) 1340 table.add_row("Processor", platform.processor()) 1341 table.add_row("Platform", platform.platform()) 1342 table.add_row("Total RAM (GB)", f"{total_memory:.2f}") 1343 console.print(table) 1344 1345 # The device information and check for available GPU acceleration. 
1346 default_device = _get_default_device() 1347 1348 if default_device == "cuda": 1349 device_index = torch.cuda.current_device() 1350 device_name = torch.cuda.get_device_name(device_index) 1351 console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information")) 1352 elif default_device == "mps": 1353 console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information")) 1354 else: 1355 console.print( 1356 Panel( 1357 "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]", 1358 title="Device Information" 1359 ) 1360 ) 1361 1362 # The section allowing to download models. 1363 # NOTE: In future, can be extended to download sample data. 1364 if args.download: 1365 download_provided_args = [t.lower() for t in args.download] 1366 mode, *model_types = download_provided_args 1367 1368 if mode not in {"models", "model"}: 1369 console.print(f"[red]Unknown option for --download: {mode}[/]") 1370 return 1371 1372 if mode in ["model", "models"] and not model_types: # If user did not specify, we will download all models. 1373 download_list = available_models 1374 else: 1375 download_list = model_types 1376 incorrect_models = [m for m in download_list if m not in available_models] 1377 if incorrect_models: 1378 console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error")) 1379 return 1380 1381 with progress.Progress( 1382 progress.SpinnerColumn(), 1383 progress.TextColumn("[progress.description]{task.description}"), 1384 progress.BarColumn(bar_width=None), 1385 "[progress.percentage]{task.percentage:>3.0f}%", 1386 progress.TimeRemainingColumn(), 1387 console=console, 1388 ) as prog: 1389 task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list)) 1390 for model_type in download_list: 1391 prog.update(task, description=f"Downloading [cyan]{model_type}[/]…") 1392 _download_sam_model(model_type=model_type) 1393 prog.advance(task) 1394 1395 console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))
Display μSAM information using a rich console.
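The function is an argparse-based entry point; when called from Python it reads any command line flags from sys.argv. A minimal sketch:

from micro_sam.util import micro_sam_info

# Prints version, documentation and cache information, the supported models,
# and system / GPU details to the terminal. The optional --download flag
# (passed on the command line) additionally fetches the requested models.
micro_sam_info()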