micro_sam.util
Helper functions for downloading Segment Anything models and predicting image embeddings.
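A minimal end-to-end sketch of the module's two main entry points (a sketch only; the model is downloaded on first use, and the random image below is just a placeholder for real data):

import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

# Load the default model. The weights are downloaded on first use and cached afterwards.
predictor = get_sam_model(model_type="vit_b_lm")

# Compute the image embeddings for a 2D image and set them on the predictor.
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder for real image data
embeddings = precompute_image_embeddings(predictor, image, ndim=2)
set_precomputed(predictor, embeddings)

The individual helpers used here are documented in the sections below.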
1""" 2Helper functions for downloading Segment Anything models and predicting image embeddings. 3""" 4 5import os 6import pickle 7import hashlib 8import warnings 9from pathlib import Path 10from collections import OrderedDict 11from typing import Any, Dict, Iterable, Optional, Tuple, Union 12 13import zarr 14import vigra 15import torch 16import pooch 17import xxhash 18import numpy as np 19import imageio.v3 as imageio 20from skimage.measure import regionprops 21from skimage.segmentation import relabel_sequential 22 23from elf.io import open_file 24 25from nifty.tools import blocking 26 27from .__version__ import __version__ 28from . import models as custom_models 29 30try: 31 # Avoid import warnigns from mobile_sam 32 with warnings.catch_warnings(): 33 warnings.simplefilter("ignore") 34 from mobile_sam import sam_model_registry, SamPredictor 35 VIT_T_SUPPORT = True 36except ImportError: 37 from segment_anything import sam_model_registry, SamPredictor 38 VIT_T_SUPPORT = False 39 40try: 41 from napari.utils import progress as tqdm 42except ImportError: 43 from tqdm import tqdm 44 45# This is the default model used in micro_sam 46# Currently it is set to vit_b_lm 47_DEFAULT_MODEL = "vit_b_lm" 48 49# The valid model types. Each type corresponds to the architecture of the 50# vision transformer used within SAM. 51_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t") 52 53 54# TODO define the proper type for image embeddings 55ImageEmbeddings = Dict[str, Any] 56"""@private""" 57 58 59def get_cache_directory() -> None: 60 """Get micro-sam cache directory location. 61 62 Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. 63 """ 64 default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) 65 cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) 66 return cache_directory 67 68 69# 70# Functionality for model download and export 71# 72 73 74def microsam_cachedir() -> None: 75 """Return the micro-sam cache directory. 76 77 Returns the top level cache directory for micro-sam models and sample data. 78 79 Every time this function is called, we check for any user updates made to 80 the MICROSAM_CACHEDIR os environment variable since the last time. 81 """ 82 cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") 83 return cache_directory 84 85 86def models(): 87 """Return the segmentation models registry. 88 89 We recreate the model registry every time this function is called, 90 so any user changes to the default micro-sam cache directory location 91 are respected. 92 """ 93 94 # We use xxhash to compute the hash of the models, see 95 # https://github.com/computational-cell-analytics/micro-sam/issues/283 96 # (It is now a dependency, so we don't provide the sha256 fallback anymore.) 97 # To generate the xxh128 hash: 98 # xxh128sum filename 99 encoder_registry = { 100 # The default segment anything models: 101 "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", 102 "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", 103 "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", 104 # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM. 105 "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", 106 # The current version of our models in the modelzoo. 
107 # LM generalist models: 108 "vit_l_lm": "xxh128:fc32ea6f7fcc7eb02737d1304f81f5f2", 109 "vit_b_lm": "xxh128:8fd5806be3c3ba213e19a709d6d1495f", 110 "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e", 111 # EM models: 112 "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb", 113 "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc", 114 "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9", 115 # Histopathology models: 116 "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974", 117 "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2", 118 "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e", 119 # Medical Imaging models: 120 "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1", 121 } 122 # Additional decoders for instance segmentation. 123 decoder_registry = { 124 # LM generalist models: 125 "vit_l_lm_decoder": "xxh128:779b5a50ecc6d46d495753fba8717f2f", 126 "vit_b_lm_decoder": "xxh128:9f580a96984b3085389ced5d9a4ae75d", 127 "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823", 128 # EM models: 129 "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb", 130 "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f", 131 "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44", 132 # Histopathology models: 133 "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213", 134 "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04", 135 "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47", 136 } 137 registry = {**encoder_registry, **decoder_registry} 138 139 encoder_urls = { 140 "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 141 "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 142 "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 143 "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", 144 "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l.pt", 145 "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b.pt", 146 "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt", 147 "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", # noqa 148 "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", 149 "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa 150 "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download", 151 "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download", 152 "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download", 153 "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download", 154 } 155 156 decoder_urls = { 157 "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l_decoder.pt", # noqa 158 "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b_decoder.pt", # noqa 159 "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt", # noqa 
160 "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", # noqa 161 "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", # noqa 162 "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa 163 "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download", 164 "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download", 165 "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download", 166 } 167 urls = {**encoder_urls, **decoder_urls} 168 169 models = pooch.create( 170 path=os.path.join(microsam_cachedir(), "models"), 171 base_url="", 172 registry=registry, 173 urls=urls, 174 ) 175 return models 176 177 178def _get_default_device(): 179 # check that we're in CI and use the CPU if we are 180 # otherwise the tests may run out of memory on MAC if MPS is used. 181 if os.getenv("GITHUB_ACTIONS") == "true": 182 return "cpu" 183 # Use cuda enabled gpu if it's available. 184 if torch.cuda.is_available(): 185 device = "cuda" 186 # As second priority use mps. 187 # See https://pytorch.org/docs/stable/notes/mps.html for details 188 elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): 189 print("Using apple MPS device.") 190 device = "mps" 191 # Use the CPU as fallback. 192 else: 193 device = "cpu" 194 return device 195 196 197def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: 198 """Get the torch device. 199 200 If no device is passed the default device for your system is used. 201 Else it will be checked if the device you have passed is supported. 202 203 Args: 204 device: The input device. 205 206 Returns: 207 The device. 208 """ 209 if device is None or device == "auto": 210 device = _get_default_device() 211 else: 212 device_type = device if isinstance(device, str) else device.type 213 if device_type.lower() == "cuda": 214 if not torch.cuda.is_available(): 215 raise RuntimeError("PyTorch CUDA backend is not available.") 216 elif device_type.lower() == "mps": 217 if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): 218 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") 219 elif device_type.lower() == "cpu": 220 pass # cpu is always available 221 else: 222 raise RuntimeError( 223 f"Unsupported device: {device}\n" 224 "Please choose from 'cpu', 'cuda', or 'mps'." 225 ) 226 227 return device 228 229 230def _available_devices(): 231 available_devices = [] 232 for i in ["cuda", "mps", "cpu"]: 233 try: 234 device = get_device(i) 235 except RuntimeError: 236 pass 237 else: 238 available_devices.append(device) 239 return available_devices 240 241 242# We write a custom unpickler that skips objects that cannot be found instead of 243# throwing an AttributeError or ModueNotFoundError. 244# NOTE: since we just want to unpickle the model to load its weights these errors don't matter. 
245# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules 246class _CustomUnpickler(pickle.Unpickler): 247 def find_class(self, module, name): 248 try: 249 return super().find_class(module, name) 250 except (AttributeError, ModuleNotFoundError) as e: 251 warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}") 252 return None 253 254 255def _compute_hash(path, chunk_size=8192): 256 hash_obj = xxhash.xxh128() 257 with open(path, "rb") as f: 258 chunk = f.read(chunk_size) 259 while chunk: 260 hash_obj.update(chunk) 261 chunk = f.read(chunk_size) 262 hash_val = hash_obj.hexdigest() 263 return f"xxh128:{hash_val}" 264 265 266# Load the state from a checkpoint. 267# The checkpoint can either contain a sam encoder state 268# or it can be a checkpoint for model finetuning. 269def _load_checkpoint(checkpoint_path): 270 # Over-ride the unpickler with our custom one. 271 # This enables imports from torch_em checkpoints even if it cannot be fully unpickled. 272 custom_pickle = pickle 273 custom_pickle.Unpickler = _CustomUnpickler 274 275 state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle) 276 if "model_state" in state: 277 # Copy the model weights from torch_em's training format. 278 model_state = state["model_state"] 279 sam_prefix = "sam." 280 model_state = OrderedDict( 281 [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] 282 ) 283 else: 284 model_state = state 285 286 return state, model_state 287 288 289def get_sam_model( 290 model_type: str = _DEFAULT_MODEL, 291 device: Optional[Union[str, torch.device]] = None, 292 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 293 return_sam: bool = False, 294 return_state: bool = False, 295 peft_kwargs: Optional[Dict] = None, 296 flexible_load_checkpoint: bool = False, 297 **model_kwargs, 298) -> SamPredictor: 299 r"""Get the SegmentAnything Predictor. 300 301 This function will download the required model or load it from the cached weight file. 302 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 303 The name of the requested model can be set via `model_type`. 304 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 305 for an overview of the available models 306 307 Alternatively this function can also load a model from weights stored in a local filepath. 308 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 309 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 310 a SAM model with vit_b encoder. 311 312 By default the models are downloaded to a folder named 'micro_sam/models' 313 inside your default cache directory, eg: 314 * Mac: ~/Library/Caches/<AppName> 315 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 316 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 317 See the pooch.os_cache() documentation for more details: 318 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 319 320 Args: 321 model_type: The Segment Anything model to use. Will use the standard `vit_l` model by default. 322 To get a list of all available model names you can call `get_model_names`. 323 device: The device for the model. If none is given will use GPU if available. 
324 checkpoint_path: The path to a file with weights that should be used instead of using the 325 weights corresponding to `model_type`. If given, `model_type` must match the architecture 326 corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder 327 then `model_type` must be given as "vit_b". 328 return_sam: Return the sam model object as well as the predictor. 329 return_state: Return the unpickled checkpoint state. 330 peft_kwargs: Keyword arguments for th PEFT wrapper class. 331 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 332 model_kwargs: Additional parameters necessary to initialize the Segment Anything model. 333 334 Returns: 335 The segment anything predictor. 336 """ 337 device = get_device(device) 338 339 # We support passing a local filepath to a checkpoint. 340 # In this case we do not download any weights but just use the local weight file, 341 # as it is, without copying it over anywhere or checking it's hashes. 342 343 # checkpoint_path has not been passed, we download a known model and derive the correct 344 # URL from the model_type. If the model_type is invalid pooch will raise an error. 345 if checkpoint_path is None: 346 model_registry = models() 347 checkpoint_path = model_registry.fetch(model_type, progressbar=True) 348 model_hash = model_registry.registry[model_type] 349 350 # If we have a custom model then we may also have a decoder checkpoint. 351 # Download it here, so that we can add it to the state. 352 decoder_name = f"{model_type}_decoder" 353 decoder_path = model_registry.fetch( 354 decoder_name, progressbar=True 355 ) if decoder_name in model_registry.registry else None 356 357 # checkpoint_path has been passed, we use it instead of downloading a model. 358 else: 359 # Check if the file exists and raise an error otherwise. 360 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 361 # (If it isn't the model creation will fail below.) 362 if not os.path.exists(checkpoint_path): 363 raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.") 364 model_hash = _compute_hash(checkpoint_path) 365 decoder_path = None 366 367 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 368 # before calling sam_model_registry. 369 abbreviated_model_type = model_type[:5] 370 if abbreviated_model_type not in _MODEL_TYPES: 371 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 372 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 373 raise RuntimeError( 374 "'mobile_sam' is required for the vit-tiny. " 375 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 376 ) 377 378 state, model_state = _load_checkpoint(checkpoint_path) 379 380 # Whether to update parameters necessary to initialize the model 381 if model_kwargs: # Checks whether model_kwargs have been provided or not 382 if abbreviated_model_type == "vit_t": 383 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 384 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 385 386 else: 387 sam = sam_model_registry[abbreviated_model_type]() 388 389 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 390 # Overwrites the SAM model by freezing the backbone and allow PEFT. 
391 if peft_kwargs and isinstance(peft_kwargs, dict): 392 # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference. 393 peft_kwargs.pop("quantize", None) 394 395 if abbreviated_model_type == "vit_t": 396 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 397 398 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 399 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 400 if flexible_load_checkpoint: 401 sam = _handle_checkpoint_loading(sam, model_state) 402 else: 403 sam.load_state_dict(model_state) 404 sam.to(device=device) 405 406 predictor = SamPredictor(sam) 407 predictor.model_type = abbreviated_model_type 408 predictor._hash = model_hash 409 predictor.model_name = model_type 410 411 # Add the decoder to the state if we have one and if the state is returned. 412 if decoder_path is not None and return_state: 413 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 414 415 if return_sam and return_state: 416 return predictor, sam, state 417 if return_sam: 418 return predictor, sam 419 if return_state: 420 return predictor, state 421 return predictor 422 423 424def _handle_checkpoint_loading(sam, model_state): 425 # Whether to handle the mismatch issues in a bit more elegant way. 426 # eg. while training for multi-class semantic segmentation in the mask encoder, 427 # parameters are updated - leading to "size mismatch" errors 428 429 new_state_dict = {} # for loading matching parameters 430 mismatched_layers = [] # for tracking mismatching parameters 431 432 reference_state = sam.state_dict() 433 434 for k, v in model_state.items(): 435 if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM. 436 if reference_state[k].size() == v.size(): 437 new_state_dict[k] = v 438 else: 439 mismatched_layers.append(k) 440 441 reference_state.update(new_state_dict) 442 443 if len(mismatched_layers) > 0: 444 warnings.warn(f"The layers with size mismatch: {mismatched_layers}") 445 446 for mlayer in mismatched_layers: 447 if 'weight' in mlayer: 448 torch.nn.init.kaiming_uniform_(reference_state[mlayer]) 449 elif 'bias' in mlayer: 450 reference_state[mlayer].zero_() 451 452 sam.load_state_dict(reference_state) 453 454 return sam 455 456 457def export_custom_sam_model( 458 checkpoint_path: Union[str, os.PathLike], 459 model_type: str, 460 save_path: Union[str, os.PathLike], 461 with_segmentation_decoder: bool = False, 462) -> None: 463 """Export a finetuned Segment Anything Model to the standard model format. 464 465 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 466 467 Args: 468 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 469 model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 470 save_path: Where to save the exported model. 471 with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. 472 If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'. 473 """ 474 _, state = get_sam_model( 475 model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu", 476 ) 477 model_state = state["model_state"] 478 prefix = "sam." 
479 model_state = OrderedDict( 480 [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] 481 ) 482 483 # Store the 'decoder_state' as well, if desired. 484 if with_segmentation_decoder: 485 if "decoder_state" not in state: 486 raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.") 487 decoder_state = state["decoder_state"] 488 save_state = {"model_state": model_state, "decoder_state": decoder_state} 489 else: 490 save_state = model_state 491 492 torch.save(save_state, save_path) 493 494 495def export_custom_qlora_model( 496 checkpoint_path: Union[str, os.PathLike], 497 finetuned_path: Union[str, os.PathLike], 498 model_type: str, 499 save_path: Union[str, os.PathLike], 500) -> None: 501 """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format. 502 503 The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`. 504 505 Args: 506 checkpoint_path: The path to the base foundation model from which the new model has been finetuned. 507 finetuned_path: The path to the new finetuned model, using QLoRA. 508 model_type: The Segment Anything Model type corresponding to the checkpoint. 509 save_path: Where to save the exported model. 510 """ 511 # Step 1: Get the base SAM model: used to start finetuning from. 512 _, sam = get_sam_model( 513 model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, 514 ) 515 516 # Step 2: Load the QLoRA-style finetuned model. 517 ft_state, ft_model_state = _load_checkpoint(finetuned_path) 518 519 # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model. 520 updated_model_state = {} 521 522 # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint. 523 for k, v in ft_model_state.items(): 524 if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1: 525 updated_model_state[k] = v 526 527 # - Next, we get all the remaining parameters from the base SAM model. 528 for k, v in sam.state_dict().items(): 529 if k.find("attn.qkv.") != -1: 530 k = k.replace("qkv", "qkv.qkv_proj") 531 updated_model_state[k] = v 532 else: 533 534 updated_model_state[k] = v 535 536 # - Finally, we replace the old model state with the new one (to retain other relevant stuff) 537 ft_state['model_state'] = updated_model_state 538 539 # Step 4: Store the new "state" to "save_path" 540 torch.save(ft_state, save_path) 541 542 543def get_model_names() -> Iterable: 544 model_registry = models() 545 model_names = model_registry.registry.keys() 546 return model_names 547 548 549# 550# Functionality for precomputing image embeddings. 551# 552 553 554def _to_image(input_): 555 # we require the input to be uint8 556 if input_.dtype != np.dtype("uint8"): 557 # first normalize the input to [0, 1] 558 input_ = input_.astype("float32") - input_.min() 559 input_ = input_ / input_.max() 560 # then bring to [0, 255] and cast to uint8 561 input_ = (input_ * 255).astype("uint8") 562 563 if input_.ndim == 2: 564 image = np.concatenate([input_[..., None]] * 3, axis=-1) 565 elif input_.ndim == 3 and input_.shape[-1] == 3: 566 image = input_ 567 else: 568 raise ValueError(f"Invalid input image of shape {input_.shape}. 
Expect either 2D grayscale or 3D RGB image.") 569 570 return image 571 572 573def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update): 574 tiling = blocking([0, 0], input_.shape[:2], tile_shape) 575 n_tiles = tiling.numberOfBlocks 576 577 features = f.require_group("features") 578 features.attrs["shape"] = input_.shape[:2] 579 features.attrs["tile_shape"] = tile_shape 580 features.attrs["halo"] = halo 581 582 pbar_init(n_tiles, "Compute Image Embeddings 2D tiled.") 583 for tile_id in range(n_tiles): 584 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 585 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 586 587 predictor.reset_image() 588 tile_input = _to_image(input_[outer_tile]) 589 predictor.set_image(tile_input) 590 tile_features = predictor.get_image_embedding() 591 original_size = predictor.original_size 592 input_size = predictor.input_size 593 594 ds = features.create_dataset( 595 str(tile_id), data=tile_features.cpu().numpy(), compression="gzip", chunks=tile_features.shape 596 ) 597 ds.attrs["original_size"] = original_size 598 ds.attrs["input_size"] = input_size 599 pbar_update(1) 600 601 _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None) 602 603 return features 604 605 606def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update): 607 assert input_.ndim == 3 608 609 shape = input_.shape[1:] 610 tiling = blocking([0, 0], shape, tile_shape) 611 n_tiles = tiling.numberOfBlocks 612 613 features = f.require_group("features") 614 features.attrs["shape"] = shape 615 features.attrs["tile_shape"] = tile_shape 616 features.attrs["halo"] = halo 617 618 n_slices = input_.shape[0] 619 pbar_init(n_tiles * n_slices, "Compute Image Embeddings 3D tiled.") 620 621 for tile_id in range(n_tiles): 622 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 623 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 624 625 ds = None 626 for z in range(n_slices): 627 predictor.reset_image() 628 tile_input = _to_image(input_[z][outer_tile]) 629 predictor.set_image(tile_input) 630 tile_features = predictor.get_image_embedding() 631 632 if ds is None: 633 shape = (input_.shape[0],) + tile_features.shape 634 chunks = (1,) + tile_features.shape 635 ds = features.create_dataset( 636 str(tile_id), shape=shape, dtype="float32", compression="gzip", chunks=chunks 637 ) 638 639 ds[z] = tile_features.cpu().numpy() 640 pbar_update(1) 641 642 original_size = predictor.original_size 643 input_size = predictor.input_size 644 645 ds.attrs["original_size"] = original_size 646 ds.attrs["input_size"] = input_size 647 648 _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None) 649 650 return features 651 652 653def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update): 654 # Check if the embeddings are already cached. 655 if save_path is not None and "input_size" in f.attrs: 656 # In this case we load the embeddings. 657 features = f["features"][:] 658 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 659 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 660 # Also set the embeddings. 661 set_precomputed(predictor, image_embeddings) 662 return image_embeddings 663 664 pbar_init(1, "Compute Image Embeddings 2D.") 665 # Otherwise we have to compute the embeddings. 
666 predictor.reset_image() 667 predictor.set_image(_to_image(input_)) 668 features = predictor.get_image_embedding().cpu().numpy() 669 original_size = predictor.original_size 670 input_size = predictor.input_size 671 pbar_update(1) 672 673 # Save the embeddings if we have a save_path. 674 if save_path is not None: 675 f.create_dataset("features", data=features, compression="gzip", chunks=features.shape) 676 _write_embedding_signature( 677 f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size, 678 ) 679 680 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 681 return image_embeddings 682 683 684def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update): 685 # Check if the features are already computed. 686 if "input_size" in f.attrs: 687 features = f["features"] 688 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 689 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 690 return image_embeddings 691 692 # Otherwise compute them. Note: saving happens automatically because we 693 # always write the features to zarr. If no save path is given we use an in-memory zarr. 694 features = _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update) 695 image_embeddings = {"features": features, "input_size": None, "original_size": None} 696 return image_embeddings 697 698 699def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update): 700 # Check if the embeddings are already fully cached. 701 if save_path is not None and "input_size" in f.attrs: 702 # In this case we load the embeddings. 703 features = f["features"] if lazy_loading else f["features"][:] 704 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 705 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 706 return image_embeddings 707 708 # Otherwise we have to compute the embeddings. 709 710 # First check if we have a save path or not and set things up accordingly. 711 if save_path is None: 712 features = [] 713 save_features = False 714 partial_features = False 715 else: 716 save_features = True 717 embed_shape = (1, 256, 64, 64) 718 shape = (input_.shape[0],) + embed_shape 719 chunks = (1,) + embed_shape 720 if "features" in f: 721 partial_features = True 722 features = f["features"] 723 if features.shape != shape or features.chunks != chunks: 724 raise RuntimeError("Invalid partial features") 725 else: 726 partial_features = False 727 features = f.create_dataset("features", shape=shape, chunks=chunks, dtype="float32") 728 729 # Initialize the pbar. 730 pbar_init(input_.shape[0], "Compute Image Embeddings 3D") 731 732 # Compute the embeddings for each slice. 733 for z, z_slice in enumerate(input_): 734 # Skip feature computation in case of partial features in non-zero slice. 
735 if partial_features and np.count_nonzero(features[z]) != 0: 736 continue 737 738 predictor.reset_image() 739 predictor.set_image(_to_image(z_slice)) 740 embedding = predictor.get_image_embedding() 741 original_size, input_size = predictor.original_size, predictor.input_size 742 743 if save_features: 744 features[z] = embedding.cpu().numpy() 745 else: 746 features.append(embedding[None]) 747 pbar_update(1) 748 749 if save_features: 750 _write_embedding_signature( 751 f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size, 752 ) 753 else: 754 # Concatenate across the z axis. 755 features = torch.cat(features).cpu().numpy() 756 757 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 758 return image_embeddings 759 760 761def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update): 762 # Check if the features are already computed. 763 if "input_size" in f.attrs: 764 features = f["features"] 765 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 766 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 767 return image_embeddings 768 769 # Otherwise compute them. Note: saving happens automatically because we 770 # always write the features to zarr. If no save path is given we use an in-memory zarr. 771 features = _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update) 772 image_embeddings = {"features": features, "input_size": None, "original_size": None} 773 return image_embeddings 774 775 776def _compute_data_signature(input_): 777 data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest() 778 return data_signature 779 780 781# Create all metadata that is stored along with the embeddings. 782def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None): 783 if data_signature is None: 784 data_signature = _compute_data_signature(input_) 785 786 signature = { 787 "data_signature": data_signature, 788 "tile_shape": tile_shape if tile_shape is None else list(tile_shape), 789 "halo": halo if halo is None else list(halo), 790 "model_type": predictor.model_type, 791 "model_name": predictor.model_name, 792 "micro_sam_version": __version__, 793 "model_hash": getattr(predictor, "_hash", None), 794 } 795 return signature 796 797 798# Note: the input size and orginal size are different if embeddings are tiled or not. 799# That's why we do not include them in the main signature that is being checked 800# (_get_embedding_signature), but just add it for serialization here. 801def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size): 802 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 803 signature.update({"input_size": input_size, "original_size": original_size}) 804 for key, val in signature.items(): 805 f.attrs[key] = val 806 807 808def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): 809 # We may have an empty zarr file that was already created to save the embeddings in. 810 # In this case the embeddings will be computed and we don't need to perform any checks. 811 if "input_size" not in f.attrs: 812 return 813 814 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 815 for key, val in signature.items(): 816 # Check whether the key is missing from the attrs or if the value is not matching. 
817 if key not in f.attrs or f.attrs[key] != val: 818 # These keys were recently added, so we don't want to fail yet if they don't 819 # match in order to not invalidate previous embedding files. 820 # Instead we just raise a warning. (For the version we probably also don't want to fail 821 # i the future since it should not invalidate the embeddings). 822 if key in ("micro_sam_version", "model_hash", "model_name"): 823 warnings.warn( 824 f"The signature for {key} in embeddings file {save_path} has a mismatch: " 825 f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " 826 "But please recompute them if model predictions don't look as expected." 827 ) 828 else: 829 raise RuntimeError( 830 f"Embeddings file {save_path} is invalid due to mismatch in {key}: " 831 f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." 832 ) 833 834 835# Helper function for optional external progress bars. 836def handle_pbar(verbose, pbar_init, pbar_update): 837 """@private""" 838 839 # Noop to provide dummy functions. 840 def noop(*args): 841 pass 842 843 if verbose and pbar_init is None: # we are verbose and don't have an external progress bar. 844 assert pbar_update is None # avoid inconsistent state of callbacks 845 846 # Create our own progress bar and callbacks 847 pbar = tqdm() 848 849 def pbar_init(total, description): 850 pbar.total = total 851 pbar.set_description(description) 852 853 def pbar_update(update): 854 pbar.update(update) 855 856 def pbar_close(): 857 pbar.close() 858 859 elif verbose and pbar_init is not None: # external pbar -> we don't have to do anything 860 assert pbar_update is not None 861 pbar = None 862 pbar_close = noop 863 864 else: # we are not verbose, do nothing 865 pbar = None 866 pbar_init, pbar_update, pbar_close = noop, noop, noop 867 868 return pbar, pbar_init, pbar_update, pbar_close 869 870 871def precompute_image_embeddings( 872 predictor: SamPredictor, 873 input_: np.ndarray, 874 save_path: Optional[Union[str, os.PathLike]] = None, 875 lazy_loading: bool = False, 876 ndim: Optional[int] = None, 877 tile_shape: Optional[Tuple[int, int]] = None, 878 halo: Optional[Tuple[int, int]] = None, 879 verbose: bool = True, 880 pbar_init: Optional[callable] = None, 881 pbar_update: Optional[callable] = None, 882) -> ImageEmbeddings: 883 """Compute the image embeddings (output of the encoder) for the input. 884 885 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 886 887 Args: 888 predictor: The SegmentAnything predictor. 889 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 890 save_path: Path to save the embeddings in a zarr container. 891 lazy_loading: Whether to load all embeddings into memory or return an 892 object to load them on demand when required. This only has an effect if 'save_path' is given 893 and if the input is 3 dimensional. 894 ndim: The dimensionality of the data. If not given will be deduced from the input data. 895 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 896 halo: Overlap of the tiles for tiled prediction. 897 verbose: Whether to be verbose in the computation. 898 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 899 Can be used together with pbar_update to handle napari progress bar in other thread. 900 To enables using this function within a threadworker. 
901 pbar_update: Callback to update an external progress bar. 902 903 Returns: 904 The image embeddings. 905 """ 906 ndim = input_.ndim if ndim is None else ndim 907 908 # Handle the embedding save_path. 909 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 910 if save_path is None: 911 f = zarr.group() 912 913 # We have a save path and it already exists. Embeddings will be loaded from it, 914 # check that the saved embeddings in there match the parameters of the function call. 915 elif os.path.exists(save_path): 916 f = zarr.open(save_path, "a") 917 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 918 919 # We have a save path and it does not exist yet. Create the zarr file to which the 920 # embeddings will then be saved. 921 else: 922 f = zarr.open(save_path, "a") 923 924 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 925 926 if ndim == 2 and tile_shape is None: 927 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 928 elif ndim == 2 and tile_shape is not None: 929 embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update) 930 elif ndim == 3 and tile_shape is None: 931 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update) 932 elif ndim == 3 and tile_shape is not None: 933 embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update) 934 else: 935 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 936 937 pbar_close() 938 return embeddings 939 940 941def set_precomputed( 942 predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None, 943) -> SamPredictor: 944 """Set the precomputed image embeddings for a predictor. 945 946 Args: 947 predictor: The SegmentAnything predictor. 948 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 949 i: Index for the image data. Required if `image` has three spatial dimensions 950 or a time dimension and two spatial dimensions. 951 tile_id: Index for the tile. This is required if the embeddings are tiled. 952 953 Returns: 954 The predictor with set features. 
955 """ 956 if tile_id is not None: 957 tile_features = image_embeddings["features"][tile_id] 958 tile_image_embeddings = { 959 "features": tile_features, 960 "input_size": tile_features.attrs["input_size"], 961 "original_size": tile_features.attrs["original_size"] 962 } 963 return set_precomputed(predictor, tile_image_embeddings, i=i) 964 965 device = predictor.device 966 features = image_embeddings["features"] 967 assert features.ndim in (4, 5), f"{features.ndim}" 968 if features.ndim == 5 and i is None: 969 raise ValueError("The data is 3D so an index i is needed.") 970 elif features.ndim == 4 and i is not None: 971 raise ValueError("The data is 2D so an index is not needed.") 972 973 if i is None: 974 predictor.features = features.to(device) if torch.is_tensor(features) else \ 975 torch.from_numpy(features[:]).to(device) 976 else: 977 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 978 torch.from_numpy(features[i]).to(device) 979 980 predictor.original_size = image_embeddings["original_size"] 981 predictor.input_size = image_embeddings["input_size"] 982 predictor.is_image_set = True 983 984 return predictor 985 986 987# 988# Misc functionality 989# 990 991 992def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 993 """Compute the intersection over union of two masks. 994 995 Args: 996 mask1: The first mask. 997 mask2: The second mask. 998 999 Returns: 1000 The intersection over union of the two masks. 1001 """ 1002 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 1003 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 1004 eps = 1e-7 1005 iou = float(overlap) / (float(union) + eps) 1006 return iou 1007 1008 1009def get_centers_and_bounding_boxes( 1010 segmentation: np.ndarray, mode: str = "v" 1011) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 1012 """Returns the center coordinates of the foreground instances in the ground-truth. 1013 1014 Args: 1015 segmentation: The segmentation. 1016 mode: Determines the functionality used for computing the centers. 1017 If 'v', the object's eccentricity centers computed by vigra are used. 1018 If 'p' the object's centroids computed by skimage are used. 1019 1020 Returns: 1021 A dictionary that maps object ids to the corresponding centroid. 1022 A dictionary that maps object_ids to the corresponding bounding box. 1023 """ 1024 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 1025 1026 properties = regionprops(segmentation) 1027 1028 if mode == "p": 1029 center_coordinates = {prop.label: prop.centroid for prop in properties} 1030 elif mode == "v": 1031 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 1032 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 1033 1034 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 1035 1036 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 1037 return center_coordinates, bbox_coordinates 1038 1039 1040def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray: 1041 """Helper function to load image data from file. 1042 1043 Args: 1044 path: The filepath to the image data. 1045 key: The internal filepath for complex data formats like hdf5. 1046 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 1047 1048 Returns: 1049 The image data. 
1050 """ 1051 if key is None: 1052 image_data = imageio.imread(path) 1053 else: 1054 with open_file(path, mode="r") as f: 1055 image_data = f[key] 1056 if not lazy_loading: 1057 image_data = image_data[:] 1058 1059 return image_data 1060 1061 1062def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor: 1063 """Convert the segmentation to one-hot encoded masks. 1064 1065 Args: 1066 segmentation: The segmentation. 1067 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1068 1069 Returns: 1070 The one-hot encoded masks. 1071 """ 1072 masks = segmentation.copy() 1073 if segmentation_ids is None: 1074 n_ids = int(segmentation.max()) 1075 1076 else: 1077 msg = "No foreground objects were found." 1078 if len(segmentation_ids) == 0: # The list should not be completely empty. 1079 raise RuntimeError(msg) 1080 1081 if 0 in segmentation_ids: # The list should not have 'zero' as a value. 1082 raise RuntimeError(msg) 1083 1084 # the segmentation ids have to be sorted 1085 segmentation_ids = np.sort(segmentation_ids) 1086 1087 # set the non selected objects to zero and relabel sequentially 1088 masks[~np.isin(masks, segmentation_ids)] = 0 1089 masks = relabel_sequential(masks)[0] 1090 n_ids = len(segmentation_ids) 1091 1092 masks = torch.from_numpy(masks) 1093 1094 one_hot_shape = (n_ids + 1,) + masks.shape 1095 masks = masks.unsqueeze(0) # add dimension to scatter 1096 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1097 1098 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1099 masks = masks.unsqueeze(1) 1100 return masks 1101 1102 1103def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1104 """Get a suitable block shape for chunking a given shape. 1105 1106 The primary use for this is determining chunk sizes for 1107 zarr arrays or block shapes for parallelization. 1108 1109 Args: 1110 shape: The image or volume shape. 1111 1112 Returns: 1113 The block shape. 1114 """ 1115 ndim = len(shape) 1116 if ndim == 2: 1117 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1118 elif ndim == 3: 1119 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1120 else: 1121 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1122 1123 return block_shape
def get_cache_directory() -> Path:
    """Get micro-sam cache directory location.

    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
    """
    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
    return cache_directory
Get micro-sam cache directory location.
Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
def microsam_cachedir() -> Union[str, Path]:
    """Return the micro-sam cache directory.

    Returns the top level cache directory for micro-sam models and sample data.

    Every time this function is called, we check for any user updates made to
    the MICROSAM_CACHEDIR os environment variable since the last time.
    """
    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
    return cache_directory
Return the micro-sam cache directory.
Returns the top level cache directory for micro-sam models and sample data.
Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.
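For example, to redirect the cache (a sketch; the directory below is a placeholder):

import os
from micro_sam.util import get_cache_directory, microsam_cachedir

# Point micro-sam at a custom cache location before any models are downloaded.
os.environ["MICROSAM_CACHEDIR"] = "/tmp/my_microsam_cache"  # placeholder path

print(microsam_cachedir())     # top level cache directory
print(get_cache_directory())   # same location, returned as a pathlib.Path

Both helpers read the environment variable at call time, so the cache location can be changed between calls.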
def models():
    """Return the segmentation models registry.

    We recreate the model registry every time this function is called,
    so any user changes to the default micro-sam cache directory location
    are respected.
    """

    # We use xxhash to compute the hash of the models, see
    # https://github.com/computational-cell-analytics/micro-sam/issues/283
    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
    # To generate the xxh128 hash:
    #     xxh128sum filename
    encoder_registry = {
        # The default segment anything models:
        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
        # The current version of our models in the modelzoo.
        # LM generalist models:
        "vit_l_lm": "xxh128:fc32ea6f7fcc7eb02737d1304f81f5f2",
        "vit_b_lm": "xxh128:8fd5806be3c3ba213e19a709d6d1495f",
        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
        # EM models:
        "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
        "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
        # Histopathology models:
        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
        # Medical Imaging models:
        "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1",
    }
    # Additional decoders for instance segmentation.
    decoder_registry = {
        # LM generalist models:
        "vit_l_lm_decoder": "xxh128:779b5a50ecc6d46d495753fba8717f2f",
        "vit_b_lm_decoder": "xxh128:9f580a96984b3085389ced5d9a4ae75d",
        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
        # EM models:
        "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
        "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
        # Histopathology models:
        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
    }
    registry = {**encoder_registry, **decoder_registry}

    encoder_urls = {
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l.pt",
        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b.pt",
        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt",  # noqa
        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download",
    }

    decoder_urls = {
        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l_decoder.pt",  # noqa
        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b_decoder.pt",  # noqa
        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt",  # noqa
        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt",  # noqa
        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
    }
    urls = {**encoder_urls, **decoder_urls}

    models = pooch.create(
        path=os.path.join(microsam_cachedir(), "models"),
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models
Return the segmentation models registry.
We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
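A short sketch of how the returned registry can be used (fetching a model downloads it on first use and reuses the cached file afterwards):

from micro_sam.util import models

registry = models()

# List all model names known to the registry (encoders and extra decoders).
print(sorted(registry.registry.keys()))

# Download (or reuse the cached copy of) a specific model and get its local path.
local_path = registry.fetch("vit_b_lm", progressbar=True)
print(local_path)

The same registry is used internally by get_sam_model, so fetching manually is only needed for custom workflows.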
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed, the default device for your system is used.
    Otherwise it is checked whether the passed device is supported.

    Args:
        device: The input device.

    Returns:
        The device.
    """
    if device is None or device == "auto":
        device = _get_default_device()
    else:
        device_type = device if isinstance(device, str) else device.type
        if device_type.lower() == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("PyTorch CUDA backend is not available.")
        elif device_type.lower() == "mps":
            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
        elif device_type.lower() == "cpu":
            pass  # cpu is always available
        else:
            raise RuntimeError(
                f"Unsupported device: {device}\n"
                "Please choose from 'cpu', 'cuda', or 'mps'."
            )

    return device
Get the torch device.
If no device is passed, the default device for your system is used. Otherwise it is checked whether the passed device is supported.
Arguments:
- device: The input device.
Returns:
The device.
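A small usage sketch:

from micro_sam.util import get_device

device = get_device()        # picks "cuda", "mps" or "cpu" depending on your system
device = get_device("auto")  # equivalent to passing None

# Passing an explicit device checks that the corresponding backend is available.
try:
    device = get_device("cuda")
except RuntimeError as err:
    print(f"CUDA is not available, falling back to the CPU: {err}")
    device = get_device("cpu")

The returned value can be passed directly as the device argument of get_sam_model.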
def get_sam_model(
    model_type: str = _DEFAULT_MODEL,
    device: Optional[Union[str, torch.device]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    return_sam: bool = False,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    **model_kwargs,
) -> SamPredictor:
    r"""Get the SegmentAnything Predictor.

    This function will download the required model or load it from the cached weight file.
    The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
    The name of the requested model can be set via `model_type`.
    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
    for an overview of the available models.

    Alternatively this function can also load a model from weights stored in a local filepath.
    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
    a SAM model with vit_b encoder.

    By default the models are downloaded to a folder named 'micro_sam/models'
    inside your default cache directory, e.g.:
    * Mac: ~/Library/Caches/<AppName>
    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
    See the pooch.os_cache() documentation for more details:
    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

    Args:
        model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
            To get a list of all available model names you can call `get_model_names`.
        device: The device for the model. If none is given will use GPU if available.
        checkpoint_path: The path to a file with weights that should be used instead of using the
            weights corresponding to `model_type`. If given, `model_type` must match the architecture
            corresponding to the weight file, e.g. if you use weights for SAM with `vit_b` encoder
            then `model_type` must be given as "vit_b".
        return_sam: Return the sam model object as well as the predictor.
        return_state: Return the unpickled checkpoint state.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.

    Returns:
        The segment anything predictor.
    """
    device = get_device(device)

    # We support passing a local filepath to a checkpoint.
    # In this case we do not download any weights but just use the local weight file,
    # as it is, without copying it over anywhere or checking its hashes.

    # checkpoint_path has not been passed, we download a known model and derive the correct
    # URL from the model_type. If the model_type is invalid pooch will raise an error.
    if checkpoint_path is None:
        model_registry = models()
        checkpoint_path = model_registry.fetch(model_type, progressbar=True)
        model_hash = model_registry.registry[model_type]

        # If we have a custom model then we may also have a decoder checkpoint.
        # Download it here, so that we can add it to the state.
        decoder_name = f"{model_type}_decoder"
        decoder_path = model_registry.fetch(
            decoder_name, progressbar=True
        ) if decoder_name in model_registry.registry else None

    # checkpoint_path has been passed, we use it instead of downloading a model.
    else:
        # Check if the file exists and raise an error otherwise.
        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
        # (If it isn't the model creation will fail below.)
        if not os.path.exists(checkpoint_path):
            raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
        model_hash = _compute_hash(checkpoint_path)
        decoder_path = None

    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
    # before calling sam_model_registry.
    abbreviated_model_type = model_type[:5]
    if abbreviated_model_type not in _MODEL_TYPES:
        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
        raise RuntimeError(
            "'mobile_sam' is required for the vit-tiny. "
            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
        )

    state, model_state = _load_checkpoint(checkpoint_path)

    # Whether to update parameters necessary to initialize the model
    if model_kwargs:  # Checks whether model_kwargs have been provided or not
        if abbreviated_model_type == "vit_t":
            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
    else:
        sam = sam_model_registry[abbreviated_model_type]()

    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
    if peft_kwargs and isinstance(peft_kwargs, dict):
        # NOTE: We drop the 'quantize' parameter, if found, as we do not quantize in inference.
        peft_kwargs.pop("quantize", None)

        if abbreviated_model_type == "vit_t":
            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
    # In case the model checkpoint has issues when it is initialized with different parameters than the default.
    if flexible_load_checkpoint:
        sam = _handle_checkpoint_loading(sam, model_state)
    else:
        sam.load_state_dict(model_state)
    sam.to(device=device)

    predictor = SamPredictor(sam)
    predictor.model_type = abbreviated_model_type
    predictor._hash = model_hash
    predictor.model_name = model_type

    # Add the decoder to the state if we have one and if the state is returned.
    if decoder_path is not None and return_state:
        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)

    if return_sam and return_state:
        return predictor, sam, state
    if return_sam:
        return predictor, sam
    if return_state:
        return predictor, state
    return predictor
Get the SegmentAnything Predictor.

This function will download the required model or load it from the cached weight file.
The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
The name of the requested model can be set via `model_type`.
See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
for an overview of the available models.

Alternatively, this function can also load a model from weights stored in a local filepath.
The corresponding file path is given via `checkpoint_path`. In this case `model_type`
must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
a SAM model with vit_b encoder.

By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g.:
- Mac: ~/Library/Caches/<AppName>
- Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
- Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

Arguments:
- model_type: The Segment Anything model to use. Will use the default model (`vit_b_lm`) if not given. To get a list of all available model names you can call `get_model_names`.
- device: The device for the model. If none is given, will use GPU if available.
- checkpoint_path: The path to a file with weights that should be used instead of the weights corresponding to `model_type`. If given, `model_type` must match the architecture corresponding to the weight file, e.g. if you use weights for SAM with a `vit_b` encoder then `model_type` must be given as "vit_b".
- return_sam: Return the sam model object as well as the predictor.
- return_state: Return the unpickled checkpoint state.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
- model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
Returns:
The segment anything predictor.
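A minimal usage sketch; the checkpoint filename below is a hypothetical placeholder for your own weight file:

```python
from micro_sam.util import get_sam_model

# Download (or load from the cache) the default model and create a predictor.
predictor = get_sam_model(model_type="vit_b_lm")

# Load custom weights from a local file; model_type must match the encoder architecture.
predictor, state = get_sam_model(
    model_type="vit_b",
    checkpoint_path="finetuned_vit_b.pth",  # hypothetical local checkpoint
    return_state=True,
)
```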
458def export_custom_sam_model( 459 checkpoint_path: Union[str, os.PathLike], 460 model_type: str, 461 save_path: Union[str, os.PathLike], 462 with_segmentation_decoder: bool = False, 463) -> None: 464 """Export a finetuned Segment Anything Model to the standard model format. 465 466 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 467 468 Args: 469 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 470 model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 471 save_path: Where to save the exported model. 472 with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. 473 If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'. 474 """ 475 _, state = get_sam_model( 476 model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu", 477 ) 478 model_state = state["model_state"] 479 prefix = "sam." 480 model_state = OrderedDict( 481 [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] 482 ) 483 484 # Store the 'decoder_state' as well, if desired. 485 if with_segmentation_decoder: 486 if "decoder_state" not in state: 487 raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.") 488 decoder_state = state["decoder_state"] 489 save_state = {"model_state": model_state, "decoder_state": decoder_state} 490 else: 491 save_state = model_state 492 493 torch.save(save_state, save_path)
Export a finetuned Segment Anything Model to the standard model format.
The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
Arguments:
- checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
- model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
- save_path: Where to save the exported model.
- with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
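A short sketch of a typical export call; the paths are hypothetical placeholders:

```python
from micro_sam.util import export_custom_sam_model

export_custom_sam_model(
    checkpoint_path="checkpoints/vit_b/best.pt",  # hypothetical finetuning checkpoint
    model_type="vit_b",
    save_path="exported_vit_b.pth",
    # Only works if the checkpoint actually contains a 'decoder_state';
    # note that this makes the export incompatible with other SAM tools.
    with_segmentation_decoder=True,
)
```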
496def export_custom_qlora_model( 497 checkpoint_path: Union[str, os.PathLike], 498 finetuned_path: Union[str, os.PathLike], 499 model_type: str, 500 save_path: Union[str, os.PathLike], 501) -> None: 502 """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format. 503 504 The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`. 505 506 Args: 507 checkpoint_path: The path to the base foundation model from which the new model has been finetuned. 508 finetuned_path: The path to the new finetuned model, using QLoRA. 509 model_type: The Segment Anything Model type corresponding to the checkpoint. 510 save_path: Where to save the exported model. 511 """ 512 # Step 1: Get the base SAM model: used to start finetuning from. 513 _, sam = get_sam_model( 514 model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, 515 ) 516 517 # Step 2: Load the QLoRA-style finetuned model. 518 ft_state, ft_model_state = _load_checkpoint(finetuned_path) 519 520 # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model. 521 updated_model_state = {} 522 523 # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint. 524 for k, v in ft_model_state.items(): 525 if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1: 526 updated_model_state[k] = v 527 528 # - Next, we get all the remaining parameters from the base SAM model. 529 for k, v in sam.state_dict().items(): 530 if k.find("attn.qkv.") != -1: 531 k = k.replace("qkv", "qkv.qkv_proj") 532 updated_model_state[k] = v 533 else: 534 535 updated_model_state[k] = v 536 537 # - Finally, we replace the old model state with the new one (to retain other relevant stuff) 538 ft_state['model_state'] = updated_model_state 539 540 # Step 4: Store the new "state" to "save_path" 541 torch.save(ft_state, save_path)
Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
Arguments:
- checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
- finetuned_path: The path to the new finetuned model, using QLoRA.
- model_type: The Segment Anything Model type corresponding to the checkpoint.
- save_path: Where to save the exported model.
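A sketch of the export and of loading the result again; all paths are hypothetical, and the `rank` value passed via `peft_kwargs` is only an assumed example of a LoRA setting:

```python
from micro_sam.util import export_custom_qlora_model, get_sam_model

export_custom_qlora_model(
    checkpoint_path="sam_vit_b_01ec64.pth",     # base SAM weights (hypothetical local copy)
    finetuned_path="qlora_finetuned_vit_b.pt",  # hypothetical QLoRA checkpoint
    model_type="vit_b",
    save_path="lora_vit_b.pt",
)

# The exported LoRA-style checkpoint can then be loaded with the PEFT wrapper.
predictor = get_sam_model(
    model_type="vit_b", checkpoint_path="lora_vit_b.pt", peft_kwargs={"rank": 4},
)
```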
872def precompute_image_embeddings( 873 predictor: SamPredictor, 874 input_: np.ndarray, 875 save_path: Optional[Union[str, os.PathLike]] = None, 876 lazy_loading: bool = False, 877 ndim: Optional[int] = None, 878 tile_shape: Optional[Tuple[int, int]] = None, 879 halo: Optional[Tuple[int, int]] = None, 880 verbose: bool = True, 881 pbar_init: Optional[callable] = None, 882 pbar_update: Optional[callable] = None, 883) -> ImageEmbeddings: 884 """Compute the image embeddings (output of the encoder) for the input. 885 886 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 887 888 Args: 889 predictor: The SegmentAnything predictor. 890 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 891 save_path: Path to save the embeddings in a zarr container. 892 lazy_loading: Whether to load all embeddings into memory or return an 893 object to load them on demand when required. This only has an effect if 'save_path' is given 894 and if the input is 3 dimensional. 895 ndim: The dimensionality of the data. If not given will be deduced from the input data. 896 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 897 halo: Overlap of the tiles for tiled prediction. 898 verbose: Whether to be verbose in the computation. 899 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 900 Can be used together with pbar_update to handle napari progress bar in other thread. 901 To enables using this function within a threadworker. 902 pbar_update: Callback to update an external progress bar. 903 904 Returns: 905 The image embeddings. 906 """ 907 ndim = input_.ndim if ndim is None else ndim 908 909 # Handle the embedding save_path. 910 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 911 if save_path is None: 912 f = zarr.group() 913 914 # We have a save path and it already exists. Embeddings will be loaded from it, 915 # check that the saved embeddings in there match the parameters of the function call. 916 elif os.path.exists(save_path): 917 f = zarr.open(save_path, "a") 918 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 919 920 # We have a save path and it does not exist yet. Create the zarr file to which the 921 # embeddings will then be saved. 922 else: 923 f = zarr.open(save_path, "a") 924 925 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 926 927 if ndim == 2 and tile_shape is None: 928 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 929 elif ndim == 2 and tile_shape is not None: 930 embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update) 931 elif ndim == 3 and tile_shape is None: 932 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update) 933 elif ndim == 3 and tile_shape is not None: 934 embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update) 935 else: 936 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 937 938 pbar_close() 939 return embeddings
Compute the image embeddings (output of the encoder) for the input.
If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
Arguments:
- predictor: The SegmentAnything predictor.
- input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
- save_path: Path to save the embeddings in a zarr container.
- lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional.
- ndim: The dimensionality of the data. If not given will be deduced from the input data.
- tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Whether to be verbose in the computation.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread; this enables using this function within a thread worker.
- pbar_update: Callback to update an external progress bar.
Returns:
The image embeddings.
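A minimal sketch with random data standing in for a real image; the zarr path is a hypothetical cache location:

```python
import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings

predictor = get_sam_model(model_type="vit_b_lm")
image = np.random.randint(0, 255, (1024, 1024), dtype="uint8")  # stand-in for real image data

# Compute the embeddings once and cache them in a zarr container for later reuse.
embeddings = precompute_image_embeddings(predictor, image, save_path="embeddings.zarr", ndim=2)

# For large images the embeddings can be computed tile-wise instead.
tiled_embeddings = precompute_image_embeddings(
    predictor, image, tile_shape=(512, 512), halo=(64, 64), ndim=2,
)
```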
942def set_precomputed( 943 predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None, 944) -> SamPredictor: 945 """Set the precomputed image embeddings for a predictor. 946 947 Args: 948 predictor: The SegmentAnything predictor. 949 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 950 i: Index for the image data. Required if `image` has three spatial dimensions 951 or a time dimension and two spatial dimensions. 952 tile_id: Index for the tile. This is required if the embeddings are tiled. 953 954 Returns: 955 The predictor with set features. 956 """ 957 if tile_id is not None: 958 tile_features = image_embeddings["features"][tile_id] 959 tile_image_embeddings = { 960 "features": tile_features, 961 "input_size": tile_features.attrs["input_size"], 962 "original_size": tile_features.attrs["original_size"] 963 } 964 return set_precomputed(predictor, tile_image_embeddings, i=i) 965 966 device = predictor.device 967 features = image_embeddings["features"] 968 assert features.ndim in (4, 5), f"{features.ndim}" 969 if features.ndim == 5 and i is None: 970 raise ValueError("The data is 3D so an index i is needed.") 971 elif features.ndim == 4 and i is not None: 972 raise ValueError("The data is 2D so an index is not needed.") 973 974 if i is None: 975 predictor.features = features.to(device) if torch.is_tensor(features) else \ 976 torch.from_numpy(features[:]).to(device) 977 else: 978 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 979 torch.from_numpy(features[i]).to(device) 980 981 predictor.original_size = image_embeddings["original_size"] 982 predictor.input_size = image_embeddings["input_size"] 983 predictor.is_image_set = True 984 985 return predictor
Set the precomputed image embeddings for a predictor.
Arguments:
- predictor: The SegmentAnything predictor.
- image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:
The predictor with set features.
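A minimal sketch showing how precomputed embeddings are attached to a predictor (random data stands in for a real image):

```python
import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

predictor = get_sam_model(model_type="vit_b_lm")
image = np.random.randint(0, 255, (512, 512), dtype="uint8")
embeddings = precompute_image_embeddings(predictor, image, ndim=2)

# Attach the cached embeddings so prompts can be evaluated without re-running the image encoder.
predictor = set_precomputed(predictor, embeddings)

# For a volume or timeseries, pass the index of the slice / frame, e.g. i=0.
```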
993def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 994 """Compute the intersection over union of two masks. 995 996 Args: 997 mask1: The first mask. 998 mask2: The second mask. 999 1000 Returns: 1001 The intersection over union of the two masks. 1002 """ 1003 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 1004 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 1005 eps = 1e-7 1006 iou = float(overlap) / (float(union) + eps) 1007 return iou
Compute the intersection over union of two masks.
Arguments:
- mask1: The first mask.
- mask2: The second mask.
Returns:
The intersection over union of the two masks.
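For example, for two partially overlapping square masks:

```python
import numpy as np
from micro_sam.util import compute_iou

mask1 = np.zeros((64, 64), dtype="uint8")
mask2 = np.zeros((64, 64), dtype="uint8")
mask1[10:30, 10:30] = 1  # 400 pixels
mask2[20:40, 20:40] = 1  # 400 pixels, overlapping mask1 in a 10 x 10 patch

iou = compute_iou(mask1, mask2)  # 100 / 700, i.e. roughly 0.14
```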
1010def get_centers_and_bounding_boxes( 1011 segmentation: np.ndarray, mode: str = "v" 1012) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 1013 """Returns the center coordinates of the foreground instances in the ground-truth. 1014 1015 Args: 1016 segmentation: The segmentation. 1017 mode: Determines the functionality used for computing the centers. 1018 If 'v', the object's eccentricity centers computed by vigra are used. 1019 If 'p' the object's centroids computed by skimage are used. 1020 1021 Returns: 1022 A dictionary that maps object ids to the corresponding centroid. 1023 A dictionary that maps object_ids to the corresponding bounding box. 1024 """ 1025 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 1026 1027 properties = regionprops(segmentation) 1028 1029 if mode == "p": 1030 center_coordinates = {prop.label: prop.centroid for prop in properties} 1031 elif mode == "v": 1032 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 1033 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 1034 1035 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 1036 1037 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 1038 return center_coordinates, bbox_coordinates
Returns the center coordinates of the foreground instances in the ground-truth.
Arguments:
- segmentation: The segmentation.
- mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p', the object's centroids computed by skimage are used.
Returns:
A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.
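A small sketch with a toy label image containing two objects:

```python
import numpy as np
from micro_sam.util import get_centers_and_bounding_boxes

segmentation = np.zeros((64, 64), dtype="uint32")
segmentation[5:20, 5:20] = 1
segmentation[30:50, 30:50] = 2

centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="v")
# centers maps each object id to its (eccentricity) center,
# boxes maps each object id to (min_row, min_col, max_row, max_col).
```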
1041def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray: 1042 """Helper function to load image data from file. 1043 1044 Args: 1045 path: The filepath to the image data. 1046 key: The internal filepath for complex data formats like hdf5. 1047 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 1048 1049 Returns: 1050 The image data. 1051 """ 1052 if key is None: 1053 image_data = imageio.imread(path) 1054 else: 1055 with open_file(path, mode="r") as f: 1056 image_data = f[key] 1057 if not lazy_loading: 1058 image_data = image_data[:] 1059 1060 return image_data
Helper function to load image data from file.
Arguments:
- path: The filepath to the image data.
- key: The internal filepath for complex data formats like hdf5.
- lazy_loading: Whether to lazily load the data. Only supported for n5 and zarr data.
Returns:
The image data.
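A sketch of both loading modes; the file names and the internal key "raw" are hypothetical:

```python
from micro_sam.util import load_image_data

# Simple image formats are read directly with imageio.
image = load_image_data("example_image.tif")

# Container formats need an internal key; with lazy loading the dataset handle
# is returned instead of an in-memory numpy array.
volume = load_image_data("example_data.zarr", key="raw", lazy_loading=True)
```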
1063def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor: 1064 """Convert the segmentation to one-hot encoded masks. 1065 1066 Args: 1067 segmentation: The segmentation. 1068 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1069 1070 Returns: 1071 The one-hot encoded masks. 1072 """ 1073 masks = segmentation.copy() 1074 if segmentation_ids is None: 1075 n_ids = int(segmentation.max()) 1076 1077 else: 1078 msg = "No foreground objects were found." 1079 if len(segmentation_ids) == 0: # The list should not be completely empty. 1080 raise RuntimeError(msg) 1081 1082 if 0 in segmentation_ids: # The list should not have 'zero' as a value. 1083 raise RuntimeError(msg) 1084 1085 # the segmentation ids have to be sorted 1086 segmentation_ids = np.sort(segmentation_ids) 1087 1088 # set the non selected objects to zero and relabel sequentially 1089 masks[~np.isin(masks, segmentation_ids)] = 0 1090 masks = relabel_sequential(masks)[0] 1091 n_ids = len(segmentation_ids) 1092 1093 masks = torch.from_numpy(masks) 1094 1095 one_hot_shape = (n_ids + 1,) + masks.shape 1096 masks = masks.unsqueeze(0) # add dimension to scatter 1097 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1098 1099 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1100 masks = masks.unsqueeze(1) 1101 return masks
Convert the segmentation to one-hot encoded masks.
Arguments:
- segmentation: The segmentation.
- segmentation_ids: Optional subset of ids that will be used to subsample the masks.
Returns:
The one-hot encoded masks.
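A small sketch on a toy segmentation with two objects:

```python
import numpy as np
from micro_sam.util import segmentation_to_one_hot

segmentation = np.zeros((64, 64), dtype="int64")
segmentation[5:20, 5:20] = 1
segmentation[30:50, 30:50] = 2

# One-hot encode all objects: resulting shape is (2, 1, 64, 64).
masks = segmentation_to_one_hot(segmentation)

# Or encode only a subset of the (non-zero) object ids: shape (1, 1, 64, 64).
subset_masks = segmentation_to_one_hot(segmentation, segmentation_ids=np.array([2]))
```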
1104def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1105 """Get a suitable block shape for chunking a given shape. 1106 1107 The primary use for this is determining chunk sizes for 1108 zarr arrays or block shapes for parallelization. 1109 1110 Args: 1111 shape: The image or volume shape. 1112 1113 Returns: 1114 The block shape. 1115 """ 1116 ndim = len(shape) 1117 if ndim == 2: 1118 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1119 elif ndim == 3: 1120 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1121 else: 1122 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1123 1124 return block_shape
Get a suitable block shape for chunking a given shape.
The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.
Arguments:
- shape: The image or volume shape.
Returns:
The block shape.
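For example:

```python
from micro_sam.util import get_block_shape

get_block_shape((512, 512))        # (512, 512): capped at the image shape
get_block_shape((2048, 2048))      # (1024, 1024): default 2d block shape
get_block_shape((64, 1024, 1024))  # (32, 256, 256): default 3d block shape
```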