micro_sam.util
Helper functions for downloading Segment Anything models and predicting image embeddings.
1""" 2Helper functions for downloading Segment Anything models and predicting image embeddings. 3""" 4 5import os 6import multiprocessing as mp 7import pickle 8import hashlib 9import warnings 10from concurrent import futures 11from pathlib import Path 12from collections import OrderedDict 13from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Callable 14 15import elf.parallel as parallel_impl 16import imageio.v3 as imageio 17import numpy as np 18import pooch 19import segment_anything.utils.amg as amg_utils 20import torch 21import vigra 22import xxhash 23import zarr 24 25from elf.io import open_file 26from nifty.tools import blocking 27from skimage.measure import regionprops 28from skimage.segmentation import relabel_sequential 29from torchvision.ops.boxes import batched_nms 30 31from .__version__ import __version__ 32from . import models as custom_models 33 34try: 35 # Avoid import warnigns from mobile_sam 36 with warnings.catch_warnings(): 37 warnings.simplefilter("ignore") 38 from mobile_sam import sam_model_registry, SamPredictor 39 VIT_T_SUPPORT = True 40except ImportError: 41 from segment_anything import sam_model_registry, SamPredictor 42 VIT_T_SUPPORT = False 43 44try: 45 from napari.utils import progress as tqdm 46except ImportError: 47 from tqdm import tqdm 48 49# This is the default model used in micro_sam 50# Currently it is set to vit_b_lm 51_DEFAULT_MODEL = "vit_b_lm" 52 53# The valid model types. Each type corresponds to the architecture of the 54# vision transformer used within SAM. 55_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t") 56 57 58ImageEmbeddings = Dict[str, Any] 59"""@private""" 60 61 62def get_cache_directory() -> None: 63 """Get micro-sam cache directory location. 64 65 Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. 66 """ 67 default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) 68 cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) 69 return cache_directory 70 71 72# 73# Functionality for model download and export 74# 75 76 77def microsam_cachedir() -> None: 78 """Return the micro-sam cache directory. 79 80 Returns the top level cache directory for micro-sam models and sample data. 81 82 Every time this function is called, we check for any user updates made to 83 the MICROSAM_CACHEDIR os environment variable since the last time. 84 """ 85 cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") 86 return cache_directory 87 88 89def models(): 90 """Return the segmentation models registry. 91 92 We recreate the model registry every time this function is called, 93 so any user changes to the default micro-sam cache directory location 94 are respected. 95 """ 96 97 # We use xxhash to compute the hash of the models, see 98 # https://github.com/computational-cell-analytics/micro-sam/issues/283 99 # (It is now a dependency, so we don't provide the sha256 fallback anymore.) 100 # To generate the xxh128 hash: 101 # xxh128sum filename 102 encoder_registry = { 103 # The default segment anything models: 104 "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", 105 "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", 106 "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", 107 # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM. 108 "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", 109 # The current version of our models in the modelzoo. 
110 # LM generalist models: 111 "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d", 112 "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186", 113 "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e", 114 # EM models: 115 "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d", 116 "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438", 117 "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9", 118 # Histopathology models: 119 "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974", 120 "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2", 121 "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e", 122 # Medical Imaging models: 123 "vit_b_medical_imaging": "xxh128:40169f1e3c03a4b67bff58249c176d92", 124 } 125 # Additional decoders for instance segmentation. 126 decoder_registry = { 127 # LM generalist models: 128 "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a", 129 "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba", 130 "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823", 131 # EM models: 132 "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294", 133 "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2", 134 "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44", 135 # Histopathology models: 136 "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213", 137 "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04", 138 "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47", 139 # Medical Imaging models: 140 "vit_b_medical_imaging_decoder": "xxh128:9e498b12f526f119b96c88be76e3b2ed", 141 } 142 registry = {**encoder_registry, **decoder_registry} 143 144 encoder_urls = { 145 "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 146 "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 147 "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 148 "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", 149 "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt", 150 "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt", 151 "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt", 152 "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt", # noqa 153 "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt", # noqa 154 "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa 155 "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download", 156 "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download", 157 "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download", 158 "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/f5Ol4FrjPQWfjUF/download", 159 } 160 161 decoder_urls = { 162 "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt", # noqa 163 "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt", # noqa 164 
"vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt", # noqa 165 "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt", # noqa 166 "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt", # noqa 167 "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa 168 "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download", 169 "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download", 170 "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download", 171 "vit_b_medical_imaging_decoder": "https://owncloud.gwdg.de/index.php/s/ahd3ZhZl2e0RIwz/download", 172 } 173 urls = {**encoder_urls, **decoder_urls} 174 175 models = pooch.create( 176 path=os.path.join(microsam_cachedir(), "models"), 177 base_url="", 178 registry=registry, 179 urls=urls, 180 ) 181 return models 182 183 184def _get_default_device(): 185 # check that we're in CI and use the CPU if we are 186 # otherwise the tests may run out of memory on MAC if MPS is used. 187 if os.getenv("GITHUB_ACTIONS") == "true": 188 return "cpu" 189 # Use cuda enabled gpu if it's available. 190 if torch.cuda.is_available(): 191 device = "cuda" 192 # As second priority use mps. 193 # See https://pytorch.org/docs/stable/notes/mps.html for details 194 elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): 195 print("Using apple MPS device.") 196 device = "mps" 197 # Use the CPU as fallback. 198 else: 199 device = "cpu" 200 return device 201 202 203def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: 204 """Get the torch device. 205 206 If no device is passed the default device for your system is used. 207 Else it will be checked if the device you have passed is supported. 208 209 Args: 210 device: The input device. By default, selects the best available device supports. 211 212 Returns: 213 The device. 214 """ 215 if device is None or device == "auto": 216 device = _get_default_device() 217 else: 218 device_type = device if isinstance(device, str) else device.type 219 if device_type.lower() == "cuda": 220 if not torch.cuda.is_available(): 221 raise RuntimeError("PyTorch CUDA backend is not available.") 222 elif device_type.lower() == "mps": 223 if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): 224 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") 225 elif device_type.lower() == "cpu": 226 pass # cpu is always available 227 else: 228 raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.") 229 230 return device 231 232 233def _available_devices(): 234 available_devices = [] 235 for i in ["cuda", "mps", "cpu"]: 236 try: 237 device = get_device(i) 238 except RuntimeError: 239 pass 240 else: 241 available_devices.append(device) 242 return available_devices 243 244 245# We write a custom unpickler that skips objects that cannot be found instead of 246# throwing an AttributeError or ModueNotFoundError. 247# NOTE: since we just want to unpickle the model to load its weights these errors don't matter. 
248# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules 249class _CustomUnpickler(pickle.Unpickler): 250 def find_class(self, module, name): 251 try: 252 return super().find_class(module, name) 253 except (AttributeError, ModuleNotFoundError) as e: 254 warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}") 255 return None 256 257 258def _compute_hash(path, chunk_size=8192): 259 hash_obj = xxhash.xxh128() 260 with open(path, "rb") as f: 261 chunk = f.read(chunk_size) 262 while chunk: 263 hash_obj.update(chunk) 264 chunk = f.read(chunk_size) 265 hash_val = hash_obj.hexdigest() 266 return f"xxh128:{hash_val}" 267 268 269# Load the state from a checkpoint. 270# The checkpoint can either contain a sam encoder state 271# or it can be a checkpoint for model finetuning. 272def _load_checkpoint(checkpoint_path): 273 # Over-ride the unpickler with our custom one. 274 # This enables imports from torch_em checkpoints even if it cannot be fully unpickled. 275 custom_pickle = pickle 276 custom_pickle.Unpickler = _CustomUnpickler 277 278 state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle) 279 if "model_state" in state: 280 # Copy the model weights from torch_em's training format. 281 model_state = state["model_state"] 282 sam_prefix = "sam." 283 model_state = OrderedDict( 284 [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] 285 ) 286 else: 287 model_state = state 288 289 return state, model_state 290 291 292def _download_sam_model(model_type, progress_bar_factory=None): 293 model_registry = models() 294 295 progress_bar = True 296 # Check if we have to download the model. 297 # If we do and have a progress bar factory, then we over-write the progress bar. 298 if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None: 299 progress_bar = progress_bar_factory(model_type) 300 301 checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar) 302 if not isinstance(progress_bar, bool): # Close the progress bar when the task finishes. 303 progress_bar.close() 304 305 model_hash = model_registry.registry[model_type] 306 307 # If we have a custom model then we may also have a decoder checkpoint. 308 # Download it here, so that we can add it to the state. 309 decoder_name = f"{model_type}_decoder" 310 decoder_path = model_registry.fetch( 311 decoder_name, progressbar=True 312 ) if decoder_name in model_registry.registry else None 313 314 return checkpoint_path, model_hash, decoder_path 315 316 317def get_sam_model( 318 model_type: str = _DEFAULT_MODEL, 319 device: Optional[Union[str, torch.device]] = None, 320 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 321 return_sam: bool = False, 322 return_state: bool = False, 323 peft_kwargs: Optional[Dict] = None, 324 flexible_load_checkpoint: bool = False, 325 progress_bar_factory: Optional[Callable] = None, 326 **model_kwargs, 327) -> SamPredictor: 328 r"""Get the Segment Anything Predictor. 329 330 This function will download the required model or load it from the cached weight file. 331 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 332 The name of the requested model can be set via `model_type`. 
333 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 334 for an overview of the available models 335 336 Alternatively this function can also load a model from weights stored in a local filepath. 337 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 338 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 339 a SAM model with vit_b encoder. 340 341 By default the models are downloaded to a folder named 'micro_sam/models' 342 inside your default cache directory, eg: 343 * Mac: ~/Library/Caches/<AppName> 344 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 345 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 346 See the pooch.os_cache() documentation for more details: 347 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 348 349 Args: 350 model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. 351 To get a list of all available model names you can call `micro_sam.util.get_model_names`. 352 device: The device for the model. If 'None' is provided, will use GPU if available. 353 checkpoint_path: The path to a file with weights that should be used instead of using the 354 weights corresponding to `model_type`. If given, `model_type` must match the architecture 355 corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder 356 then `model_type` must be given as 'vit_b'. 357 return_sam: Return the sam model object as well as the predictor. By default, set to 'False'. 358 return_state: Return the unpickled checkpoint state. By default, set to 'False'. 359 peft_kwargs: Keyword arguments for th PEFT wrapper class. 360 If passed 'None', it does not initialize any parameter efficient finetuning. 361 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 362 By default, set to 'False'. 363 progress_bar_factory: A function to create a progress bar for the model download. 364 model_kwargs: Additional parameters necessary to initialize the Segment Anything model. 365 366 Returns: 367 The Segment Anything predictor. 368 """ 369 device = get_device(device) 370 371 # We support passing a local filepath to a checkpoint. 372 # In this case we do not download any weights but just use the local weight file, 373 # as it is, without copying it over anywhere or checking it's hashes. 374 375 # checkpoint_path has not been passed, we download a known model and derive the correct 376 # URL from the model_type. If the model_type is invalid pooch will raise an error. 377 _provided_checkpoint_path = checkpoint_path is not None 378 if checkpoint_path is None: 379 checkpoint_path, model_hash, decoder_path = _download_sam_model(model_type, progress_bar_factory) 380 381 # checkpoint_path has been passed, we use it instead of downloading a model. 382 else: 383 # Check if the file exists and raise an error otherwise. 384 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 385 # (If it isn't the model creation will fail below.) 386 if not os.path.exists(checkpoint_path): 387 raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.") 388 model_hash = _compute_hash(checkpoint_path) 389 decoder_path = None 390 391 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 392 # before calling sam_model_registry. 
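# Example (usage sketch, not part of the module source): loading the default model,
# and loading a local fine-tuned checkpoint. The checkpoint file name is hypothetical;
# when 'checkpoint_path' is given, 'model_type' must name the matching encoder.
from micro_sam.util import get_sam_model

predictor = get_sam_model()  # downloads/loads the default "vit_b_lm" model

predictor_ft = get_sam_model(
    model_type="vit_b",                     # encoder architecture of the weights
    checkpoint_path="finetuned_vit_b.pth",  # hypothetical local weight file
)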
393 abbreviated_model_type = model_type[:5] 394 if abbreviated_model_type not in _MODEL_TYPES: 395 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 396 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 397 raise RuntimeError( 398 "'mobile_sam' is required for the vit-tiny. " 399 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 400 ) 401 402 state, model_state = _load_checkpoint(checkpoint_path) 403 404 if _provided_checkpoint_path: 405 # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type' 406 # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination. 407 from micro_sam.models.build_sam import _validate_model_type 408 _provided_model_type = _validate_model_type(model_state) 409 410 # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type' 411 # Otherwise replace 'abbreviated_model_type' with the later. 412 if abbreviated_model_type != _provided_model_type: 413 # Printing the message below to avoid any filtering of warnings on user's end. 414 print( 415 f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', " 416 f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. " 417 f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. " 418 "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future." 419 ) 420 421 # Replace the extracted 'abbreviated_model_type' subjected to the model weights. 422 abbreviated_model_type = _provided_model_type 423 424 # Whether to update parameters necessary to initialize the model 425 if model_kwargs: # Checks whether model_kwargs have been provided or not 426 if abbreviated_model_type == "vit_t": 427 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 428 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 429 430 else: 431 sam = sam_model_registry[abbreviated_model_type]() 432 433 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 434 # Overwrites the SAM model by freezing the backbone and allow PEFT. 435 if peft_kwargs and isinstance(peft_kwargs, dict): 436 # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference. 437 peft_kwargs.pop("quantize", None) 438 439 if abbreviated_model_type == "vit_t": 440 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 441 442 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 443 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 444 if flexible_load_checkpoint: 445 sam = _handle_checkpoint_loading(sam, model_state) 446 else: 447 sam.load_state_dict(model_state) 448 sam.to(device=device) 449 450 predictor = SamPredictor(sam) 451 predictor.model_type = abbreviated_model_type 452 predictor._hash = model_hash 453 predictor.model_name = model_type 454 predictor.checkpoint_path = checkpoint_path 455 456 # Add the decoder to the state if we have one and if the state is returned. 
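# Example (usage sketch, not part of the module source): the returned predictor carries
# the metadata set above, and 'return_state=True' additionally returns the unpickled
# checkpoint state (including a "decoder_state" entry for models that ship an extra
# instance-segmentation decoder).
from micro_sam.util import get_sam_model

predictor, state = get_sam_model(model_type="vit_b_lm", return_state=True)
print(predictor.model_type)       # "vit_b" - the underlying encoder architecture
print(predictor.model_name)       # "vit_b_lm" - the requested model name
print(predictor._hash)            # xxh128 hash of the model weights
print("decoder_state" in state)   # True for the micro_sam models with a decoder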
457 if decoder_path is not None and return_state: 458 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 459 460 if return_sam and return_state: 461 return predictor, sam, state 462 if return_sam: 463 return predictor, sam 464 if return_state: 465 return predictor, state 466 return predictor 467 468 469def _handle_checkpoint_loading(sam, model_state): 470 # Whether to handle the mismatch issues in a bit more elegant way. 471 # eg. while training for multi-class semantic segmentation in the mask encoder, 472 # parameters are updated - leading to "size mismatch" errors 473 474 new_state_dict = {} # for loading matching parameters 475 mismatched_layers = [] # for tracking mismatching parameters 476 477 reference_state = sam.state_dict() 478 479 for k, v in model_state.items(): 480 if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM. 481 if reference_state[k].size() == v.size(): 482 new_state_dict[k] = v 483 else: 484 mismatched_layers.append(k) 485 486 reference_state.update(new_state_dict) 487 488 if len(mismatched_layers) > 0: 489 warnings.warn(f"The layers with size mismatch: {mismatched_layers}") 490 491 for mlayer in mismatched_layers: 492 if 'weight' in mlayer: 493 torch.nn.init.kaiming_uniform_(reference_state[mlayer]) 494 elif 'bias' in mlayer: 495 reference_state[mlayer].zero_() 496 497 sam.load_state_dict(reference_state) 498 499 return sam 500 501 502def export_custom_sam_model( 503 checkpoint_path: Union[str, os.PathLike], 504 model_type: str, 505 save_path: Union[str, os.PathLike], 506 with_segmentation_decoder: bool = False, 507 prefix: str = "sam.", 508) -> None: 509 """Export a finetuned Segment Anything Model to the standard model format. 510 511 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 512 513 Args: 514 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 515 model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 516 save_path: Where to save the exported model. 517 with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. 518 If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'. 519 prefix: The prefix to remove from the model parameter keys. 520 """ 521 state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path) 522 model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]) 523 524 # Store the 'decoder_state' as well, if desired. 525 if with_segmentation_decoder: 526 if "decoder_state" not in state: 527 raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.") 528 decoder_state = state["decoder_state"] 529 save_state = {"model_state": model_state, "decoder_state": decoder_state} 530 else: 531 save_state = model_state 532 533 torch.save(save_state, save_path) 534 535 536def export_custom_qlora_model( 537 checkpoint_path: Optional[Union[str, os.PathLike]], 538 finetuned_path: Union[str, os.PathLike], 539 model_type: str, 540 save_path: Union[str, os.PathLike], 541) -> None: 542 """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format. 543 544 The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`. 
545 546 Args: 547 checkpoint_path: The path to the base foundation model from which the new model has been finetuned. 548 finetuned_path: The path to the new finetuned model, using QLoRA. 549 model_type: The Segment Anything Model type corresponding to the checkpoint. 550 save_path: Where to save the exported model. 551 """ 552 # Step 1: Get the base SAM model: used to start finetuning from. 553 _, sam = get_sam_model( 554 model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, 555 ) 556 557 # Step 2: Load the QLoRA-style finetuned model. 558 ft_state, ft_model_state = _load_checkpoint(finetuned_path) 559 560 # Step 3: Identify LoRA layers from QLoRA model. 561 # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers. 562 # - then copy the LoRA layers from the QLoRA model to the new state dict 563 updated_model_state = {} 564 565 modified_attn_layers = set() 566 modified_mlp_layers = set() 567 568 for k, v in ft_model_state.items(): 569 if "blocks." in k: 570 layer_id = int(k.split("blocks.")[1].split(".")[0]) 571 if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1: 572 modified_attn_layers.add(layer_id) 573 updated_model_state[k] = v 574 if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1: 575 modified_mlp_layers.add(layer_id) 576 updated_model_state[k] = v 577 578 # Step 4: Next, we get all the remaining parameters from the base SAM model. 579 for k, v in sam.state_dict().items(): 580 if "blocks." in k: 581 layer_id = int(k.split("blocks.")[1].split(".")[0]) 582 if k.find("attn.qkv.") != -1: 583 if layer_id in modified_attn_layers: # We have LoRA in QKV layers, so we need to modify the key 584 k = k.replace("qkv", "qkv.qkv_proj") 585 elif k.find("mlp") != -1 and k.find("image_encoder") != -1: 586 if layer_id in modified_mlp_layers: # We have LoRA in MLP layers, so we need to modify the key 587 k = k.replace("mlp.", "mlp.mlp_layer.") 588 updated_model_state[k] = v 589 590 # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff) 591 ft_state['model_state'] = updated_model_state 592 593 # Step 6: Store the new "state" to "save_path" 594 torch.save(ft_state, save_path) 595 596 597def get_model_names() -> Iterable: 598 model_registry = models() 599 model_names = model_registry.registry.keys() 600 return model_names 601 602 603# 604# Functionality for precomputing image embeddings. 605# 606 607 608def _to_image(image): 609 input_ = image 610 ndim = input_.ndim 611 n_channels = 1 if ndim == 2 else input_.shape[-1] 612 613 # Map the input to three channels. 614 if ndim == 2: # Grayscale image -> replicate channels. 615 input_ = np.concatenate([input_[..., None]] * 3, axis=-1) 616 elif ndim == 3 and n_channels == 1: # Grayscale image -> replicate channels. 617 input_ = np.concatenate([input_] * 3, axis=-1) 618 elif ndim == 3 and n_channels == 2: # Two channels -> add a zero channel. 619 zero_channel = np.zeros(input_.shape[:2] + (1,), dtype=input_.dtype) 620 input_ = np.concatenate([input_, zero_channel], axis=-1) 621 elif input_.ndim == 3 and n_channels == 3: # RGB input -> do nothing. 622 pass 623 elif input_.ndim == 3 and n_channels > 3: # More than three channels -> select first three. 624 warnings.warn(f"You provided an input with {n_channels} channels. Only the first three will be used.") 625 input_ = input_[..., :3] 626 else: 627 raise ValueError( 628 f"Invalid input dimensionality {ndim}. 
Expect either a 2D input (=grayscale image) " 629 "or a 3D input (= image with channels)." 630 ) 631 assert input_.ndim == 3 and input_.shape[-1] == 3 632 633 # Normalize the input per channel and bring it to uint8. 634 input_ = input_.astype("float32") 635 input_ -= input_.min(axis=(0, 1))[None, None] 636 input_ /= (input_.max(axis=(0, 1))[None, None] + 1e-7) 637 input_ = (input_ * 255).astype("uint8") 638 639 # Explicitly return a numpy array for compatibility with torchvision 640 # because the input_ array could be something like dask array. 641 return np.array(input_) 642 643 644@torch.no_grad 645def _compute_embeddings_batched(predictor, batched_images): 646 predictor.reset_image() 647 batched_tensors, original_sizes, input_sizes = [], [], [] 648 649 # Apply proeprocessing to all images in the batch, and then stack them. 650 # Note: after the transformation the images are all of the same size, 651 # so they can be stacked and processed as a batch, even if the input images were of different size. 652 for image in batched_images: 653 tensor = predictor.transform.apply_image(image) 654 tensor = torch.as_tensor(tensor, device=predictor.device) 655 tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :] 656 657 original_sizes.append(image.shape[:2]) 658 input_sizes.append(tensor.shape[-2:]) 659 660 tensor = predictor.model.preprocess(tensor) 661 batched_tensors.append(tensor) 662 663 batched_tensors = torch.cat(batched_tensors) 664 features = predictor.model.image_encoder(batched_tensors) 665 666 predictor.original_size = original_sizes[-1] 667 predictor.input_size = input_sizes[-1] 668 predictor.features = features[-1] 669 predictor.is_image_set = True 670 671 return features, original_sizes, input_sizes 672 673 674# Wrapper of zarr.create dataset to support zarr v2 and zarr v3. 675def _create_dataset_with_data(group, name, data, chunks=None): 676 zarr_major_version = int(zarr.__version__.split(".")[0]) 677 if chunks is None: 678 chunks = data.shape 679 if zarr_major_version == 2: 680 ds = group.create_dataset(name, data=data, shape=data.shape, chunks=chunks) 681 elif zarr_major_version == 3: 682 ds = group.create_array(name, shape=data.shape, chunks=chunks, dtype=data.dtype) 683 ds[:] = data 684 else: 685 raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}") 686 return ds 687 688 689def _create_dataset_without_data(group, name, shape, dtype, chunks): 690 zarr_major_version = int(zarr.__version__.split(".")[0]) 691 if zarr_major_version == 2: 692 ds = group.create_dataset(name, shape=shape, dtype=dtype, chunks=chunks) 693 elif zarr_major_version == 3: 694 ds = group.create_array(name, shape=shape, chunks=chunks, dtype=dtype) 695 else: 696 raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}") 697 return ds 698 699 700def _write_batch(features, tile_ids, batched_embeddings, original_sizes, input_sizes, slices=None, n_slices=None): 701 702 # Pre-create / pre-fetch the datasets if we have slices. 
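# Example (standalone illustration, not part of the module source): the per-channel
# min/max normalization to uint8 that '_to_image' above applies before an image is
# passed to the predictor, written out for a single grayscale numpy array.
import numpy as np

raw = np.random.rand(512, 512).astype("float32")        # any 2D grayscale image
rgb = np.concatenate([raw[..., None]] * 3, axis=-1)     # replicate to three channels
rgb -= rgb.min(axis=(0, 1))[None, None]
rgb /= (rgb.max(axis=(0, 1))[None, None] + 1e-7)
rgb = (rgb * 255).astype("uint8")                       # uint8 RGB, as SAM expects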
703 # (Dataset creation is not thread-safe) 704 if slices is not None: 705 datasets = {} 706 for tile_id, tile_embeddings, original_size, input_size in zip( 707 tile_ids, batched_embeddings, original_sizes, input_sizes 708 ): 709 ds_name = str(tile_id) 710 if ds_name in datasets: 711 continue 712 if ds_name in features: 713 datasets[ds_name] = features[ds_name] 714 continue 715 shape = (n_slices, 1) + tile_embeddings.shape 716 chunks = (1, 1) + tile_embeddings.shape 717 ds = _create_dataset_without_data(features, ds_name, shape=shape, dtype="float32", chunks=chunks) 718 ds.attrs["original_size"] = original_size 719 ds.attrs["input_size"] = input_size 720 datasets[ds_name] = ds 721 722 def _write_embed(i): 723 ds_name = str(tile_ids[i]) 724 tile_embeddings = batched_embeddings[i].unsqueeze(0) 725 if slices is None: 726 ds = _create_dataset_with_data(features, ds_name, data=tile_embeddings.cpu().numpy()) 727 ds.attrs["original_size"] = original_sizes[i] 728 ds.attrs["input_size"] = input_sizes[i] 729 elif ds_name in features: 730 ds = datasets[ds_name] 731 z = slices[i] 732 ds[z] = tile_embeddings.cpu().numpy() 733 734 n_tiles = len(tile_ids) 735 n_workers = min(mp.cpu_count(), n_tiles) 736 with futures.ThreadPoolExecutor(n_workers) as tp: 737 list(tp.map(_write_embed, range(n_tiles))) 738 739 740def _get_tiles_in_mask(mask, tiling, halo, z=None): 741 def _check_mask(tile_id): 742 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 743 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 744 if z is not None: 745 outer_tile = (z,) + outer_tile 746 tile_mask = mask[outer_tile].astype("bool") 747 return None if tile_mask.sum() == 0 else tile_id 748 749 n_threads = mp.cpu_count() 750 with futures.ThreadPoolExecutor(n_threads) as tp: 751 tiles_in_mask = tp.map(_check_mask, range(tiling.numberOfBlocks)) 752 return sorted([tile_id for tile_id in tiles_in_mask if tile_id is not None]) 753 754 755def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask): 756 tiling = blocking([0, 0], input_.shape[:2], tile_shape) 757 n_tiles = tiling.numberOfBlocks 758 759 features = f.require_group("features") 760 features.attrs["shape"] = input_.shape[:2] 761 features.attrs["tile_shape"] = tile_shape 762 features.attrs["halo"] = halo 763 764 n_batches = int(np.ceil(n_tiles / batch_size)) 765 if mask is None: 766 tile_ids_for_batches = [ 767 range(batch_id * batch_size, min((batch_id + 1) * batch_size, n_tiles)) 768 for batch_id in range(n_batches) 769 ] 770 pbar_init(n_tiles, "Compute Image Embeddings 2D tiled") 771 else: 772 tiles_in_mask = _get_tiles_in_mask(mask, tiling, halo) 773 pbar_init(len(tiles_in_mask), "Compute Image Embeddings 2D tiled with mask") 774 tile_ids_for_batches = np.array_split(tiles_in_mask, n_batches) 775 assert len(tile_ids_for_batches) == n_batches 776 777 for tile_ids in tile_ids_for_batches: 778 batched_images = [] 779 for tile_id in tile_ids: 780 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 781 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 782 tile_input = _to_image(input_[outer_tile]) 783 batched_images.append(tile_input) 784 785 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 786 _write_batch(features, tile_ids, batched_embeddings, original_sizes, input_sizes) 787 pbar_update(len(tile_ids)) 788 789 _write_embedding_signature(f, input_, predictor, tile_shape, halo, 
input_size=None, original_size=None) 790 if mask is not None: 791 features.attrs["tiles_in_mask"] = tiles_in_mask 792 793 return features 794 795 796class _BatchProvider: 797 def __init__(self, n_slices, n_tiles_per_plane, tiles_in_mask_per_slice, batch_size): 798 if tiles_in_mask_per_slice is None: 799 self.n_tiles_total = n_slices * n_tiles_per_plane 800 else: 801 self.n_tiles_total = sum(len(val) for val in tiles_in_mask_per_slice.values()) 802 803 self.n_batches = int(np.ceil(self.n_tiles_total / batch_size)) 804 self.n_slices = n_slices 805 self.n_tiles_per_plane = n_tiles_per_plane 806 self.tiles_in_mask_per_slice = tiles_in_mask_per_slice 807 self.batch_size = batch_size 808 809 # Iter variables. 810 self.batch_id = 0 811 self.z = 0 812 self.tile_id = 0 813 814 def __iter__(self): 815 return self 816 817 def __next__(self): 818 if self.batch_id >= self.n_batches: 819 raise StopIteration 820 821 z_list = list(range(self.n_tiles_per_plane)) 822 z_tiles = z_list if self.tiles_in_mask_per_slice is None else self.tiles_in_mask_per_slice[self.z] 823 824 slices, tile_ids = [], [] 825 this_batch_size = 0 826 while this_batch_size < self.batch_size: 827 if self.tile_id == len(z_tiles): 828 self.z += 1 829 self.tile_id = 0 830 if self.z >= self.n_slices: 831 break 832 z_tiles = z_list if self.tiles_in_mask_per_slice is None else self.tiles_in_mask_per_slice[self.z] 833 continue 834 835 slices.append(self.z), tile_ids.append(z_tiles[self.tile_id]) 836 self.tile_id += 1 837 this_batch_size += 1 838 839 self.batch_id += 1 840 return slices, tile_ids 841 842 843def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask): 844 assert input_.ndim == 3 845 846 shape = input_.shape[1:] 847 tiling = blocking([0, 0], shape, tile_shape) 848 features = f.require_group("features") 849 features.attrs["shape"] = shape 850 features.attrs["tile_shape"] = tile_shape 851 features.attrs["halo"] = halo 852 853 n_tiles_per_plane = tiling.numberOfBlocks 854 n_slices = input_.shape[0] 855 856 msg = "Compute Image Embeddings 3D tiled" 857 if mask is None: 858 n_tiles_total = n_slices * n_tiles_per_plane 859 tiles_in_mask_per_slice = None 860 else: 861 tiles_in_mask_per_slice = {} 862 for z in range(n_slices): 863 tiles_in_mask_per_slice[z] = _get_tiles_in_mask(mask, tiling, halo, z=z) 864 n_tiles_total = sum(len(val) for val in tiles_in_mask_per_slice.values()) 865 msg += " masked" 866 pbar_init(n_tiles_total, msg) 867 868 batch_provider = _BatchProvider(n_slices, n_tiles_per_plane, tiles_in_mask_per_slice, batch_size) 869 for slices, tile_ids in batch_provider: 870 batched_images = [] 871 for z, tile_id in zip(slices, tile_ids): 872 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 873 outer_tile = (z,) + tuple( 874 slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end) 875 ) 876 tile_input = _to_image(input_[outer_tile]) 877 batched_images.append(tile_input) 878 879 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 880 _write_batch( 881 features, tile_ids, batched_embeddings, original_sizes, input_sizes, slices=slices, n_slices=n_slices 882 ) 883 pbar_update(len(tile_ids)) 884 885 if mask is not None: 886 features.attrs["tiles_in_mask"] = {str(z): per_slice for z, per_slice in tiles_in_mask_per_slice.items()} 887 888 _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None) 889 return features 890 891 892def _compute_2d(input_, predictor, 
f, save_path, pbar_init, pbar_update): 893 # Check if the embeddings are already cached. 894 if save_path is not None and "input_size" in f.attrs: 895 # In this case we load the embeddings. 896 features = f["features"][:] 897 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 898 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 899 # Also set the embeddings. 900 set_precomputed(predictor, image_embeddings) 901 return image_embeddings 902 903 pbar_init(1, "Compute Image Embeddings 2D") 904 # Otherwise we have to compute the embeddings. 905 predictor.reset_image() 906 predictor.set_image(_to_image(input_)) 907 features = predictor.get_image_embedding().cpu().numpy() 908 original_size = predictor.original_size 909 input_size = predictor.input_size 910 pbar_update(1) 911 912 # Save the embeddings if we have a save_path. 913 if save_path is not None: 914 _create_dataset_with_data(f, "features", data=features) 915 _write_embedding_signature( 916 f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size, 917 ) 918 919 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 920 return image_embeddings 921 922 923def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask): 924 # Check if the features are already computed. 925 if "input_size" in f.attrs: 926 features = f["features"] 927 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 928 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 929 return image_embeddings 930 931 # Otherwise compute them. Note: saving happens automatically because we 932 # always write the features to zarr. If no save path is given we use an in-memory zarr. 933 features = _compute_tiled_features_2d( 934 predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask 935 ) 936 image_embeddings = {"features": features, "input_size": None, "original_size": None} 937 return image_embeddings 938 939 940def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size): 941 # Check if the embeddings are already fully cached. 942 if save_path is not None and "input_size" in f.attrs: 943 # In this case we load the embeddings. 944 features = f["features"] if lazy_loading else f["features"][:] 945 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 946 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 947 return image_embeddings 948 949 # Otherwise we have to compute the embeddings. 950 951 # First check if we have a save path or not and set things up accordingly. 952 if save_path is None: 953 features = [] 954 save_features = False 955 partial_features = False 956 else: 957 save_features = True 958 embed_shape = (1, 256, 64, 64) 959 shape = (input_.shape[0],) + embed_shape 960 chunks = (1,) + embed_shape 961 if "features" in f: 962 partial_features = True 963 features = f["features"] 964 if features.shape != shape or features.chunks != chunks: 965 raise RuntimeError("Invalid partial features") 966 else: 967 partial_features = False 968 features = _create_dataset_without_data(f, "features", shape=shape, chunks=chunks, dtype="float32") 969 970 # Initialize the pbar and batches. 
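# Example (usage sketch, not part of the module source): precomputing embeddings for a
# 3D volume through the public 'precompute_image_embeddings' defined further below.
# With a 'save_path' the per-slice features are written to the zarr dataset created
# above, with shape (n_slices, 1, 256, 64, 64); the file name here is hypothetical.
import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings

volume = np.random.randint(0, 255, size=(16, 512, 512)).astype("uint8")
predictor = get_sam_model(model_type="vit_b_lm")
embeddings = precompute_image_embeddings(
    predictor, volume, save_path="volume_embeddings.zarr", ndim=3, batch_size=4,
)
print(embeddings["features"].shape)  # (16, 1, 256, 64, 64)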
971 n_slices = input_.shape[0] 972 pbar_init(n_slices, "Compute Image Embeddings 3D") 973 n_batches = int(np.ceil(n_slices / batch_size)) 974 975 for batch_id in range(n_batches): 976 z_start = batch_id * batch_size 977 z_stop = min(z_start + batch_size, n_slices) 978 979 batched_images, batched_z = [], [] 980 for z in range(z_start, z_stop): 981 # Skip feature computation in case of partial features in non-zero slice. 982 if partial_features and np.count_nonzero(features[z]) != 0: 983 continue 984 tile_input = _to_image(input_[z]) 985 batched_images.append(tile_input) 986 batched_z.append(z) 987 988 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 989 990 for z, embedding in zip(batched_z, batched_embeddings): 991 embedding = embedding.unsqueeze(0) 992 if save_features: 993 features[z] = embedding.cpu().numpy() 994 else: 995 features.append(embedding.unsqueeze(0)) 996 pbar_update(1) 997 998 if save_features: 999 _write_embedding_signature( 1000 f, input_, predictor, tile_shape=None, halo=None, 1001 input_size=input_sizes[-1], original_size=original_sizes[-1], 1002 ) 1003 else: 1004 # Concatenate across the z axis. 1005 features = torch.cat(features).cpu().numpy() 1006 1007 image_embeddings = {"features": features, "input_size": input_sizes[-1], "original_size": original_sizes[-1]} 1008 return image_embeddings 1009 1010 1011def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask): 1012 # Check if the features are already computed. 1013 if "input_size" in f.attrs: 1014 features = f["features"] 1015 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 1016 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 1017 return image_embeddings 1018 1019 # Otherwise compute them. Note: saving happens automatically because we 1020 # always write the features to zarr. If no save path is given we use an in-memory zarr. 1021 features = _compute_tiled_features_3d( 1022 predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask 1023 ) 1024 image_embeddings = {"features": features, "input_size": None, "original_size": None} 1025 return image_embeddings 1026 1027 1028def _compute_data_signature(input_): 1029 data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest() 1030 return data_signature 1031 1032 1033# Create all metadata that is stored along with the embeddings. 1034def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None): 1035 if data_signature is None: 1036 data_signature = _compute_data_signature(input_) 1037 1038 signature = { 1039 "data_signature": data_signature, 1040 "tile_shape": tile_shape if tile_shape is None else list(tile_shape), 1041 "halo": halo if halo is None else list(halo), 1042 "model_type": predictor.model_type, 1043 "model_name": predictor.model_name, 1044 "micro_sam_version": __version__, 1045 "model_hash": getattr(predictor, "_hash", None), 1046 } 1047 return signature 1048 1049 1050# Note: the input size and orginal size are different if embeddings are tiled or not. 1051# That's why we do not include them in the main signature that is being checked 1052# (_get_embedding_signature), but just add it for serialization here. 
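# Example (standalone illustration, not part of the module source): the data signature
# stored in the embedding metadata is a sha1 hash of the raw input bytes, so cached
# embeddings are only reused for exactly the same input data.
import hashlib
import numpy as np

image = np.random.randint(0, 255, size=(512, 512)).astype("uint8")
data_signature = hashlib.sha1(np.asarray(image).tobytes()).hexdigest()
print(data_signature)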
1053def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size): 1054 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 1055 signature.update({"input_size": input_size, "original_size": original_size}) 1056 for key, val in signature.items(): 1057 f.attrs[key] = val 1058 1059 1060def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): 1061 # We may have an empty zarr file that was already created to save the embeddings in. 1062 # In this case the embeddings will be computed and we don't need to perform any checks. 1063 if "input_size" not in f.attrs: 1064 return 1065 1066 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 1067 for key, val in signature.items(): 1068 # Check whether the key is missing from the attrs or if the value is not matching. 1069 if key not in f.attrs or f.attrs[key] != val: 1070 # These keys were recently added, so we don't want to fail yet if they don't 1071 # match in order to not invalidate previous embedding files. 1072 # Instead we just raise a warning. (For the version we probably also don't want to fail 1073 # i the future since it should not invalidate the embeddings). 1074 if key in ("micro_sam_version", "model_hash", "model_name"): 1075 warnings.warn( 1076 f"The signature for {key} in embeddings file {save_path} has a mismatch: " 1077 f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " 1078 "But please recompute them if model predictions don't look as expected." 1079 ) 1080 else: 1081 raise RuntimeError( 1082 f"Embeddings file {save_path} is invalid due to mismatch in {key}: " 1083 f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." 1084 ) 1085 1086 1087# Helper function for optional external progress bars. 1088def handle_pbar(verbose, pbar_init, pbar_update): 1089 """@private""" 1090 1091 # Noop to provide dummy functions. 1092 def noop(*args): 1093 pass 1094 1095 if verbose and pbar_init is None: # we are verbose and don't have an external progress bar. 1096 assert pbar_update is None # avoid inconsistent state of callbacks 1097 1098 # Create our own progress bar and callbacks 1099 pbar = tqdm() 1100 1101 def pbar_init(total, description): 1102 pbar.total = total 1103 pbar.set_description(description) 1104 1105 def pbar_update(update): 1106 pbar.update(update) 1107 1108 def pbar_close(): 1109 pbar.close() 1110 1111 elif verbose and pbar_init is not None: # external pbar -> we don't have to do anything 1112 assert pbar_update is not None 1113 pbar = None 1114 pbar_close = noop 1115 1116 else: # we are not verbose, do nothing 1117 pbar = None 1118 pbar_init, pbar_update, pbar_close = noop, noop, noop 1119 1120 return pbar, pbar_init, pbar_update, pbar_close 1121 1122 1123def precompute_image_embeddings( 1124 predictor: SamPredictor, 1125 input_: np.ndarray, 1126 save_path: Optional[Union[str, os.PathLike]] = None, 1127 lazy_loading: bool = False, 1128 ndim: Optional[int] = None, 1129 tile_shape: Optional[Tuple[int, int]] = None, 1130 halo: Optional[Tuple[int, int]] = None, 1131 verbose: bool = True, 1132 batch_size: int = 1, 1133 mask: Optional[np.typing.ArrayLike] = None, 1134 pbar_init: Optional[callable] = None, 1135 pbar_update: Optional[callable] = None, 1136) -> ImageEmbeddings: 1137 """Compute the image embeddings (output of the encoder) for the input. 1138 1139 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 
1140 1141 Args: 1142 predictor: The Segment Anything predictor. 1143 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 1144 save_path: Path to save the embeddings in a zarr container. 1145 By default, set to 'None', i.e. the computed embeddings will not be stored locally. 1146 lazy_loading: Whether to load all embeddings into memory or return an 1147 object to load them on demand when required. This only has an effect if 'save_path' is given 1148 and if the input is 3 dimensional. By default, set to 'False'. 1149 ndim: The dimensionality of the data. If not given will be deduced from the input data. 1150 By default, set to 'None', i.e. will be computed from the provided `input_`. 1151 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 1152 halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling. 1153 verbose: Whether to be verbose in the computation. By default, set to 'True'. 1154 batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'. 1155 mask: An optional mask to define areas that are ignored in the computation. 1156 The mask will be used within tiled embedding computation and tiles that don't contain any foreground 1157 in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings. 1158 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1159 Can be used together with pbar_update to handle napari progress bar in other thread. 1160 To enables using this function within a threadworker. 1161 pbar_update: Callback to update an external progress bar. 1162 1163 Returns: 1164 The image embeddings. 1165 """ 1166 ndim = input_.ndim if ndim is None else ndim 1167 1168 # Handle the embedding save_path. 1169 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 1170 if save_path is None: 1171 f = zarr.group() 1172 1173 # We have a save path and it already exists. Embeddings will be loaded from it, 1174 # check that the saved embeddings in there match the parameters of the function call. 1175 elif os.path.exists(save_path): 1176 f = zarr.open(save_path, mode="a") 1177 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 1178 1179 # We have a save path and it does not exist yet. Create the zarr file to which the 1180 # embeddings will then be saved. 
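# Example (usage sketch, not part of the module source): tiled 2D embeddings with a
# save path. A second call with the same parameters re-uses the cached zarr file after
# the signature check above; 'set_precomputed' (defined further below) then loads the
# features of a single tile into the predictor. File name and tile id are hypothetical.
import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

image = np.random.randint(0, 255, size=(2048, 2048)).astype("uint8")
predictor = get_sam_model(model_type="vit_b_lm")
embeddings = precompute_image_embeddings(
    predictor, image, save_path="tiled_embeddings.zarr", tile_shape=(1024, 1024), halo=(256, 256),
)
set_precomputed(predictor, embeddings, tile_id=0)  # prompts can now be run on tile 0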
1181 else: 1182 f = zarr.open(save_path, mode="a") 1183 1184 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 1185 1186 if ndim == 2 and tile_shape is None: 1187 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 1188 elif ndim == 2 and tile_shape is not None: 1189 embeddings = _compute_tiled_2d( 1190 input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask 1191 ) 1192 elif ndim == 3 and tile_shape is None: 1193 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size) 1194 elif ndim == 3 and tile_shape is not None: 1195 embeddings = _compute_tiled_3d( 1196 input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask 1197 ) 1198 else: 1199 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 1200 1201 pbar_close() 1202 return embeddings 1203 1204 1205def set_precomputed( 1206 predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None, 1207) -> SamPredictor: 1208 """Set the precomputed image embeddings for a predictor. 1209 1210 Args: 1211 predictor: The Segment Anything predictor. 1212 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 1213 i: Index for the image data. Required if `image` has three spatial dimensions 1214 or a time dimension and two spatial dimensions. 1215 tile_id: Index for the tile. This is required if the embeddings are tiled. 1216 1217 Returns: 1218 The predictor with set features. 1219 """ 1220 if tile_id is not None: 1221 tile_features = image_embeddings["features"][str(tile_id)] 1222 tile_image_embeddings = { 1223 "features": tile_features, 1224 "input_size": tile_features.attrs["input_size"], 1225 "original_size": tile_features.attrs["original_size"] 1226 } 1227 return set_precomputed(predictor, tile_image_embeddings, i=i) 1228 1229 device = predictor.device 1230 features = image_embeddings["features"] 1231 assert features.ndim in (4, 5), f"{features.ndim}" 1232 if features.ndim == 5 and i is None: 1233 raise ValueError("The data is 3D so an index i is needed.") 1234 elif features.ndim == 4 and i is not None: 1235 raise ValueError("The data is 2D so an index is not needed.") 1236 1237 if i is None: 1238 predictor.features = features.to(device) if torch.is_tensor(features) else \ 1239 torch.from_numpy(features[:]).to(device) 1240 else: 1241 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 1242 torch.from_numpy(features[i]).to(device) 1243 1244 predictor.original_size = image_embeddings["original_size"] 1245 predictor.input_size = image_embeddings["input_size"] 1246 predictor.is_image_set = True 1247 1248 return predictor 1249 1250 1251# 1252# Misc functionality 1253# 1254 1255 1256def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 1257 """Compute the intersection over union of two masks. 1258 1259 Args: 1260 mask1: The first mask. 1261 mask2: The second mask. 1262 1263 Returns: 1264 The intersection over union of the two masks. 
1265 """ 1266 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 1267 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 1268 eps = 1e-7 1269 iou = float(overlap) / (float(union) + eps) 1270 return iou 1271 1272 1273def get_centers_and_bounding_boxes( 1274 segmentation: np.ndarray, mode: str = "v" 1275) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 1276 """Returns the center coordinates of the foreground instances in the ground-truth. 1277 1278 Args: 1279 segmentation: The segmentation. 1280 mode: Determines the functionality used for computing the centers. 1281 If 'v', the object's eccentricity centers computed by vigra are used. 1282 If 'p' the object's centroids computed by skimage are used. 1283 1284 Returns: 1285 A dictionary that maps object ids to the corresponding centroid. 1286 A dictionary that maps object_ids to the corresponding bounding box. 1287 """ 1288 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 1289 1290 properties = regionprops(segmentation) 1291 1292 if mode == "p": 1293 center_coordinates = {prop.label: prop.centroid for prop in properties} 1294 elif mode == "v": 1295 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 1296 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 1297 1298 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 1299 1300 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 1301 return center_coordinates, bbox_coordinates 1302 1303 1304def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray: 1305 """Helper function to load image data from file. 1306 1307 Args: 1308 path: The filepath to the image data. 1309 key: The internal filepath for complex data formats like hdf5. 1310 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 1311 1312 Returns: 1313 The image data. 1314 """ 1315 if key is None: 1316 image_data = imageio.imread(path) 1317 else: 1318 with open_file(path, mode="r") as f: 1319 image_data = f[key] 1320 if not lazy_loading: 1321 image_data = image_data[:] 1322 1323 return image_data 1324 1325 1326def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor: 1327 """Convert the segmentation to one-hot encoded masks. 1328 1329 Args: 1330 segmentation: The segmentation. 1331 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1332 By default, computes the number of ids from the provided `segmentation` masks. 1333 1334 Returns: 1335 The one-hot encoded masks. 1336 """ 1337 masks = segmentation.copy() 1338 if segmentation_ids is None: 1339 n_ids = int(segmentation.max()) 1340 1341 else: 1342 msg = "No foreground objects were found." 1343 if len(segmentation_ids) == 0: # The list should not be completely empty. 1344 raise RuntimeError(msg) 1345 1346 if 0 in segmentation_ids: # The list should not have 'zero' as a value. 
1347 raise RuntimeError(msg) 1348 1349 # the segmentation ids have to be sorted 1350 segmentation_ids = np.sort(segmentation_ids) 1351 1352 # set the non selected objects to zero and relabel sequentially 1353 masks[~np.isin(masks, segmentation_ids)] = 0 1354 masks = relabel_sequential(masks)[0] 1355 n_ids = len(segmentation_ids) 1356 1357 masks = torch.from_numpy(masks) 1358 1359 one_hot_shape = (n_ids + 1,) + masks.shape 1360 masks = masks.unsqueeze(0) # add dimension to scatter 1361 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1362 1363 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1364 masks = masks.unsqueeze(1) 1365 return masks 1366 1367 1368def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1369 """Get a suitable block shape for chunking a given shape. 1370 1371 The primary use for this is determining chunk sizes for 1372 zarr arrays or block shapes for parallelization. 1373 1374 Args: 1375 shape: The image or volume shape. 1376 1377 Returns: 1378 The block shape. 1379 """ 1380 ndim = len(shape) 1381 if ndim == 2: 1382 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1383 elif ndim == 3: 1384 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1385 else: 1386 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1387 1388 return block_shape 1389 1390 1391def micro_sam_info() -> None: 1392 """Display μSAM information using a rich console.""" 1393 import psutil 1394 import platform 1395 import argparse 1396 from rich import progress 1397 from rich.panel import Panel 1398 from rich.table import Table 1399 from rich.console import Console 1400 1401 import torch 1402 import micro_sam 1403 1404 parser = argparse.ArgumentParser(description="μSAM Information Booth") 1405 parser.add_argument( 1406 "--download", nargs="+", metavar=("WHAT", "KIND"), 1407 help="Downloads the pretrained SAM models." 1408 "'--download models' -> downloads all pretrained models; " 1409 "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; " 1410 "'--download model/models vit_b_lm' -> downloads a single specified model." 1411 ) 1412 args = parser.parse_args() 1413 1414 # Open up a new console. 1415 console = Console() 1416 1417 # The header for information CLI. 1418 console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center") 1419 console.print("-" * console.width) 1420 1421 # μSAM version panel. 1422 console.print( 1423 Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True) 1424 ) 1425 1426 # The documentation link panel. 1427 console.print( 1428 Panel( 1429 "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n" 1430 "https://computational-cell-analytics.github.io/micro-sam", title="Documentation" 1431 ) 1432 ) 1433 1434 # The publication panel. 1435 console.print( 1436 Panel( 1437 "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n" 1438 "https://www.nature.com/articles/s41592-024-02580-4", title="Publication" 1439 ) 1440 ) 1441 1442 # Creating a cache directory when users' run `micro_sam.info`. 1443 cache_dir = get_cache_directory() 1444 os.makedirs(cache_dir, exist_ok=True) 1445 1446 # The cache directory panel. 1447 console.print( 1448 Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{cache_dir}", title="Cache Directory") 1449 ) 1450 1451 # We have a simple versioning logic here (which is what I'll follow here for mapping model versions). 
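# Example (usage sketch, not part of the module source): the small segmentation helpers
# defined above, applied to a toy label image with two objects.
import numpy as np
from micro_sam.util import (
    compute_iou, get_centers_and_bounding_boxes, segmentation_to_one_hot, get_block_shape,
)

segmentation = np.zeros((256, 256), dtype="int64")
segmentation[32:96, 32:96] = 1
segmentation[128:200, 128:200] = 2

print(compute_iou(segmentation == 1, segmentation == 2))   # ~0.0 - disjoint masks
centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="p")
print(centers[1], boxes[1])                                 # centroid and bbox of object 1
one_hot = segmentation_to_one_hot(segmentation)
print(one_hot.shape)                                        # (2, 1, 256, 256)
print(get_block_shape(segmentation.shape))                  # (256, 256)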
1452 available_models = [] 1453 for model_name, model_path in models().urls.items(): # We filter out the decoder models. 1454 if model_name.endswith("decoder"): 1455 continue 1456 1457 if "https://dl.fbaipublicfiles.com/segment_anything/" in model_path: # Valid v1 SAM models. 1458 available_models.append(model_name) 1459 1460 if "https://owncloud.gwdg.de/" in model_path: # Our own hosted models (in their v1 mode quite often) 1461 if model_name == "vit_t": # MobileSAM model. 1462 available_models.append(model_name) 1463 else: 1464 available_models.append(f"{model_name} (v1)") 1465 1466 # Now for our models, the BioImageIO ModelZoo upload structure is such that: 1467 # '/1/files' corresponds to v2 models. 1468 # '/1.1/files' corresponds to v3 models. 1469 # '/1.2/files' corresponds to v4 models. 1470 if "/1/files" in model_path: 1471 available_models.append(f"{model_name} (v2)") 1472 if "/1.1/files" in model_path: 1473 available_models.append(f"{model_name} (v3)") 1474 if "/1.2/files" in model_path: 1475 available_models.append(f"{model_name} (v4)") 1476 1477 model_list = "\n".join(available_models) 1478 1479 # The available models panel. 1480 console.print( 1481 Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models") 1482 ) 1483 1484 # The system information table. 1485 total_memory = psutil.virtual_memory().total / (1024 ** 3) 1486 table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True) 1487 table.add_column("Property") 1488 table.add_column("Value", style="bold #56B4E9") 1489 table.add_row("System", platform.system()) 1490 table.add_row("Node Name", platform.node()) 1491 table.add_row("Release", platform.release()) 1492 table.add_row("Version", platform.version()) 1493 table.add_row("Machine", platform.machine()) 1494 table.add_row("Processor", platform.processor()) 1495 table.add_row("Platform", platform.platform()) 1496 table.add_row("Total RAM (GB)", f"{total_memory:.2f}") 1497 console.print(table) 1498 1499 # The device information and check for available GPU acceleration. 1500 default_device = _get_default_device() 1501 1502 if default_device == "cuda": 1503 device_index = torch.cuda.current_device() 1504 device_name = torch.cuda.get_device_name(device_index) 1505 console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information")) 1506 elif default_device == "mps": 1507 console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information")) 1508 else: 1509 console.print( 1510 Panel( 1511 "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]", 1512 title="Device Information" 1513 ) 1514 ) 1515 1516 # The section allowing to download models. 1517 # NOTE: In future, can be extended to download sample data. 1518 if args.download: 1519 download_provided_args = [t.lower() for t in args.download] 1520 mode, *model_types = download_provided_args 1521 1522 if mode not in {"models", "model"}: 1523 console.print(f"[red]Unknown option for --download: {mode}[/]") 1524 return 1525 1526 if mode in ["model", "models"] and not model_types: # If user did not specify, we will download all models. 
1527 download_list = available_models 1528 else: 1529 download_list = model_types 1530 incorrect_models = [m for m in download_list if m not in available_models] 1531 if incorrect_models: 1532 console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error")) 1533 return 1534 1535 with progress.Progress( 1536 progress.SpinnerColumn(), 1537 progress.TextColumn("[progress.description]{task.description}"), 1538 progress.BarColumn(bar_width=None), 1539 "[progress.percentage]{task.percentage:>3.0f}%", 1540 progress.TimeRemainingColumn(), 1541 console=console, 1542 ) as prog: 1543 task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list)) 1544 for model_type in download_list: 1545 prog.update(task, description=f"Downloading [cyan]{model_type}[/]…") 1546 _download_sam_model(model_type=model_type) 1547 prog.advance(task) 1548 1549 console.print(Panel("[bold green] Downloads complete![/]", title="Finished")) 1550 1551 1552# 1553# Functionality to convert mask predictions to an instance segmentation via non-maximum suppression. 1554# The functionality for computing NMS for masks is taken from CellSeg1: 1555# https://github.com/Nuisal/cellseg1/blob/1c027c2568b83494d2662d1fbecec9aafb478ee0/mask_nms.py 1556# 1557 1558 1559def _overlap_matrix(boxes): 1560 x1 = torch.max(boxes[:, None, 0], boxes[:, 0]) 1561 y1 = torch.max(boxes[:, None, 1], boxes[:, 1]) 1562 x2 = torch.min(boxes[:, None, 2], boxes[:, 2]) 1563 y2 = torch.min(boxes[:, None, 3], boxes[:, 3]) 1564 1565 w = torch.clamp(x2 - x1, min=0) 1566 h = torch.clamp(y2 - y1, min=0) 1567 1568 return (w * h) > 0 1569 1570 1571def _calculate_ious_between_pred_masks(masks, boxes, diagonal_value=1): 1572 n_points = masks.shape[0] 1573 m = torch.zeros((n_points, n_points)) 1574 1575 overlap_m = _overlap_matrix(boxes) 1576 1577 for i in range(n_points): 1578 js = torch.where(overlap_m[i])[0] 1579 js_half = js[js > i].to(masks.device) 1580 1581 if len(js_half) > 0: 1582 intersection = torch.logical_and(masks[i], masks[js_half]).sum(dim=(1, 2)) 1583 union = torch.logical_or(masks[i], masks[js_half]).sum(dim=(1, 2)) 1584 iou = intersection / union 1585 m[i, js_half] = iou 1586 1587 m = m + m.T 1588 m.fill_diagonal_(diagonal_value) 1589 return m 1590 1591 1592def _calculate_iomin_between_pred_masks(masks, boxes, eps=1e-6): 1593 overlap_m = _overlap_matrix(boxes) 1594 1595 # Flatten spatial dimensions: (N, H*W) or (N, D*H*W) 1596 N = masks.shape[0] 1597 masks_flat = masks.reshape(N, -1).float() 1598 1599 # Per-mask area 1600 areas = masks_flat.sum(dim=1) # (N,) 1601 1602 # Pairwise intersections via matrix multiplication 1603 # inter[i, j] = sum_k masks_flat[i, k] * masks_flat[j, k] 1604 inter = masks_flat @ masks_flat.t() # (N, N) 1605 1606 # Denominator: min area of the two masks 1607 min_areas = torch.minimum(areas[:, None], areas[None, :]) # (N, N) 1608 1609 # IoMin = intersection / min(area_i, area_j) 1610 iomin = inter / (min_areas + eps) 1611 1612 # Set elements without any overlap explicitly to zero. 
1613 iomin[~overlap_m] = 0 1614 return iomin 1615 1616 1617def _batched_mask_nms(masks, boxes, scores, nms_thresh, intersection_over_min): 1618 boxes = ( 1619 boxes.detach() if isinstance(boxes, torch.Tensor) else torch.tensor(boxes) 1620 ).cpu() 1621 scores = ( 1622 scores.detach() if isinstance(scores, torch.Tensor) else torch.tensor(scores) 1623 ).cpu() 1624 masks = ( 1625 masks.detach() if isinstance(masks, torch.Tensor) else torch.tensor(masks) 1626 ).cpu() 1627 1628 if intersection_over_min: 1629 iou_matrix = _calculate_iomin_between_pred_masks(masks, boxes) 1630 else: 1631 iou_matrix = _calculate_ious_between_pred_masks(masks, boxes) 1632 sorted_indices = torch.argsort(scores, descending=True) 1633 1634 keep = [] 1635 while len(sorted_indices) > 0: 1636 i = sorted_indices[0] 1637 keep.append(i) 1638 1639 if len(sorted_indices) == 1: 1640 break 1641 1642 iou_values = iou_matrix[i, sorted_indices[1:]] 1643 mask = iou_values <= nms_thresh 1644 sorted_indices = sorted_indices[1:][mask] 1645 1646 return torch.tensor(keep) 1647 1648 1649def mask_data_to_segmentation( 1650 masks: List[Dict[str, Any]], 1651 shape: Optional[Tuple[int, int]] = None, 1652 min_object_size: int = 0, 1653 max_object_size: Optional[int] = None, 1654 label_masks: bool = True, 1655 with_background: bool = False, 1656 merge_exclusively: bool = True, 1657) -> np.ndarray: 1658 """Convert the output of the automatic mask generation to an instance segmentation. 1659 1660 Args: 1661 masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`, 1662 or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask. 1663 shape: The shape of the output segmentation. If None, it will be derived from the mask input. 1664 If the mask where predicted with tiling then the shape must be given. 1665 min_object_size: The minimal size of an object in pixels. By default, set to '0'. 1666 max_object_size: The maximal size of an object in pixels. 1667 label_masks: Whether to apply connected components to the result before removing small objects. 1668 By default, set to 'True'. 1669 with_background: Whether to remove the largest object, which often covers the background for AMG. 1670 merge_exclusively: Whether to exclude previous merged masks from merging. 1671 1672 Returns: 1673 The instance segmentation. 
1674 """ 1675 masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) 1676 if shape is None: 1677 shape = next(iter(masks))["segmentation"].shape 1678 segmentation = np.zeros(shape, dtype="uint32") 1679 1680 def require_numpy(mask): 1681 return mask.cpu().numpy() if torch.is_tensor(mask) else mask 1682 1683 seg_id = 1 1684 for mask_data in masks: 1685 area = mask_data["area"] 1686 if (area < min_object_size) or (max_object_size is not None and area > max_object_size): 1687 continue 1688 1689 this_mask = require_numpy(mask_data["segmentation"]) 1690 this_seg_id = mask_data.get("seg_id", seg_id) 1691 if "global_bbox" in mask_data: 1692 bb = mask_data["bbox"] 1693 bb = np.s_[bb[1]:bb[1] + bb[3], bb[0]:bb[0] + bb[2]] 1694 global_bb = mask_data["global_bbox"] 1695 global_bb = np.s_[global_bb[1]:global_bb[1] + global_bb[3], global_bb[0]:global_bb[0] + global_bb[2]] 1696 if merge_exclusively: 1697 this_mask = np.logical_and(this_mask[bb], segmentation[global_bb] == 0) 1698 else: 1699 this_mask = this_mask[bb] 1700 segmentation[global_bb][this_mask] = this_seg_id 1701 else: 1702 if merge_exclusively: 1703 this_mask = np.logical_and(this_mask, segmentation == 0) 1704 segmentation[this_mask] = this_seg_id 1705 seg_id = this_seg_id + 1 1706 1707 block_shape = (512, 512) 1708 if label_masks: 1709 segmentation_cc = np.zeros_like(segmentation, dtype=segmentation.dtype) 1710 segmentation_cc = parallel_impl.label(segmentation, out=segmentation_cc, block_shape=block_shape) 1711 segmentation = segmentation_cc 1712 1713 seg_ids, sizes = parallel_impl.unique(segmentation, return_counts=True, block_shape=block_shape) 1714 filter_ids = seg_ids[sizes < min_object_size] 1715 if with_background: 1716 bg_id = seg_ids[np.argmax(sizes)] 1717 filter_ids = np.concatenate([filter_ids, [bg_id]]) 1718 1719 filter_mask = np.zeros(segmentation.shape, dtype="bool") 1720 filter_mask = parallel_impl.isin(segmentation, filter_ids, out=filter_mask, block_shape=block_shape) 1721 segmentation[filter_mask] = 0 1722 parallel_impl.relabel_consecutive(segmentation, block_shape=block_shape)[0] 1723 1724 return segmentation 1725 1726 1727def apply_nms( 1728 predictions: List[Dict[str, Any]], 1729 min_size: int, 1730 shape: Optional[Tuple[int, int]] = None, 1731 perform_box_nms: bool = False, 1732 nms_thresh: float = 0.9, 1733 max_size: Optional[int] = None, 1734 intersection_over_min: bool = False, 1735) -> np.ndarray: 1736 """Apply non-maximum suppression to mask predictions from a segment anything model. 1737 1738 Args: 1739 predictions: The mask predictions from SAM. 1740 min_size: The minimum mask size to keep in the output. 1741 shape: The shape of the output segmentation. 1742 Has to be passed for predictions obtained from tiling. 1743 perform_box_nms: Whether to perform NMS on the box coordinates or on the masks. 1744 nms_thresh: The threshold for filtering out objects in NMS. 1745 max_size: The maximum mask size to keep in the output. 1746 intersection_over_min: Whether to perform intersection over the minimum overlap shape 1747 or to perform intersection over union. 1748 1749 Returns: 1750 The segmentation obtained from merging the masks left after NMS. 
1751 """ 1752 data = amg_utils.MaskData( 1753 masks=torch.cat([pred["segmentation"][None] for pred in predictions], dim=0), 1754 iou_preds=torch.tensor([pred["predicted_iou"] for pred in predictions]), 1755 ) 1756 data["boxes"] = torch.tensor(np.array([pred["bbox"] for pred in predictions])) 1757 data["area"] = [mask.sum() for mask in data["masks"]] 1758 data["stability_scores"] = torch.tensor([pred["stability_score"] for pred in predictions]) 1759 1760 # Check if the input comes with a 'global_bbox' attribute. If it does, then the predictions are from 1761 # a tiled prediction. In this case, we have to take the coordinates w.r.t. the tiling into account. 1762 if "global_bbox" in predictions[0]: 1763 if shape is None: 1764 raise ValueError("The output shape 'shape' has to be passed for tiled predictions.") 1765 data["global_boxes"] = torch.tensor(np.array([pred["global_bbox"] for pred in predictions])) 1766 is_tiled = True 1767 else: 1768 is_tiled = False 1769 1770 if min_size > 0: 1771 keep_by_size = torch.tensor( 1772 [i for i, area in enumerate(data["area"]) if area > min_size], dtype=torch.long, 1773 ) 1774 data.filter(keep_by_size) 1775 1776 if max_size is not None: 1777 keep_by_size = torch.tensor([i for i, area in enumerate(data["area"]) if area < max_size]) 1778 data.filter(keep_by_size) 1779 1780 scores = data["iou_preds"] * data["stability_scores"] 1781 if perform_box_nms: 1782 assert not intersection_over_min # not implemented 1783 keep_by_nms = batched_nms( 1784 data["global_boxes"].float() if is_tiled else data["boxes"].float(), 1785 scores, 1786 torch.zeros_like(data["boxes"][:, 0]), # categories 1787 iou_threshold=nms_thresh, 1788 ) 1789 else: 1790 keep_by_nms = _batched_mask_nms( 1791 masks=data["masks"], 1792 boxes=data["global_boxes"].float() if is_tiled else data["boxes"].float(), 1793 scores=scores, 1794 nms_thresh=nms_thresh, 1795 intersection_over_min=intersection_over_min, 1796 ) 1797 data.filter(keep_by_nms) 1798 1799 if is_tiled: 1800 mask_data = [ 1801 {"segmentation": mask, "area": area, "bbox": box, "global_bbox": global_box} 1802 for mask, area, box, global_box in zip(data["masks"], data["area"], data["boxes"], data["global_boxes"]) 1803 ] 1804 else: 1805 mask_data = [ 1806 {"segmentation": mask, "area": area, "bbox": box} 1807 for mask, area, box in zip(data["masks"], data["area"], data["boxes"]) 1808 ] 1809 1810 if shape is None: 1811 shape = predictions[0]["segmentation"].shape 1812 if mask_data: 1813 segmentation = mask_data_to_segmentation(mask_data, shape=shape, min_object_size=min_size) 1814 else: # In case all objects have been filtered out due to size filtering. 1815 segmentation = np.zeros(shape, dtype="uint32") 1816 1817 return segmentation
63def get_cache_directory() -> None: 64 """Get micro-sam cache directory location. 65 66 Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. 67 """ 68 default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) 69 cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) 70 return cache_directory
Get micro-sam cache directory location.
Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
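A minimal usage sketch; the cache path below is only an illustrative placeholder, not a recommendation:

import os
os.environ["MICROSAM_CACHEDIR"] = "/tmp/my_microsam_cache"  # hypothetical custom cache location

from micro_sam.util import get_cache_directory
print(get_cache_directory())  # e.g. PosixPath('/tmp/my_microsam_cache') on Linux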
78def microsam_cachedir() -> None: 79 """Return the micro-sam cache directory. 80 81 Returns the top level cache directory for micro-sam models and sample data. 82 83 Every time this function is called, we check for any user updates made to 84 the MICROSAM_CACHEDIR os environment variable since the last time. 85 """ 86 cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") 87 return cache_directory
Return the micro-sam cache directory.
Returns the top level cache directory for micro-sam models and sample data.
Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR OS environment variable since the last call.

90def models(): 91 """Return the segmentation models registry. 92 93 We recreate the model registry every time this function is called, 94 so any user changes to the default micro-sam cache directory location 95 are respected. 96 """ 97 98 # We use xxhash to compute the hash of the models, see 99 # https://github.com/computational-cell-analytics/micro-sam/issues/283 100 # (It is now a dependency, so we don't provide the sha256 fallback anymore.) 101 # To generate the xxh128 hash: 102 # xxh128sum filename 103 encoder_registry = { 104 # The default segment anything models: 105 "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", 106 "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", 107 "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", 108 # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM. 109 "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", 110 # The current version of our models in the modelzoo. 111 # LM generalist models: 112 "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d", 113 "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186", 114 "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e", 115 # EM models: 116 "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d", 117 "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438", 118 "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9", 119 # Histopathology models: 120 "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974", 121 "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2", 122 "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e", 123 # Medical Imaging models: 124 "vit_b_medical_imaging": "xxh128:40169f1e3c03a4b67bff58249c176d92", 125 } 126 # Additional decoders for instance segmentation. 
127 decoder_registry = { 128 # LM generalist models: 129 "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a", 130 "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba", 131 "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823", 132 # EM models: 133 "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294", 134 "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2", 135 "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44", 136 # Histopathology models: 137 "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213", 138 "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04", 139 "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47", 140 # Medical Imaging models: 141 "vit_b_medical_imaging_decoder": "xxh128:9e498b12f526f119b96c88be76e3b2ed", 142 } 143 registry = {**encoder_registry, **decoder_registry} 144 145 encoder_urls = { 146 "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 147 "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 148 "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 149 "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", 150 "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt", 151 "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt", 152 "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt", 153 "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt", # noqa 154 "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt", # noqa 155 "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa 156 "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download", 157 "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download", 158 "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download", 159 "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/f5Ol4FrjPQWfjUF/download", 160 } 161 162 decoder_urls = { 163 "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt", # noqa 164 "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt", # noqa 165 "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt", # noqa 166 "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt", # noqa 167 "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt", # noqa 168 "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa 169 "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download", 170 "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download", 171 "vit_h_histopathology_decoder": 
"https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download", 172 "vit_b_medical_imaging_decoder": "https://owncloud.gwdg.de/index.php/s/ahd3ZhZl2e0RIwz/download", 173 } 174 urls = {**encoder_urls, **decoder_urls} 175 176 models = pooch.create( 177 path=os.path.join(microsam_cachedir(), "models"), 178 base_url="", 179 registry=registry, 180 urls=urls, 181 ) 182 return models
Return the segmentation models registry.
We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
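A small sketch of querying the registry directly; for most use cases get_sam_model below is the more convenient entry point, and "vit_b_lm" is just one of the registered model names:

from micro_sam.util import models

registry = models()
print(sorted(registry.registry.keys()))  # all registered encoder and decoder names
# Download (or reuse from the cache) the checkpoint for one model and return its local path.
checkpoint_path = registry.fetch("vit_b_lm")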
204def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: 205 """Get the torch device. 206 207 If no device is passed the default device for your system is used. 208 Else it will be checked if the device you have passed is supported. 209 210 Args: 211 device: The input device. By default, selects the best available device supports. 212 213 Returns: 214 The device. 215 """ 216 if device is None or device == "auto": 217 device = _get_default_device() 218 else: 219 device_type = device if isinstance(device, str) else device.type 220 if device_type.lower() == "cuda": 221 if not torch.cuda.is_available(): 222 raise RuntimeError("PyTorch CUDA backend is not available.") 223 elif device_type.lower() == "mps": 224 if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): 225 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") 226 elif device_type.lower() == "cpu": 227 pass # cpu is always available 228 else: 229 raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.") 230 231 return device
Get the torch device.
If no device is passed, the default device for your system is used. Otherwise, it is checked whether the device you passed is supported.
Arguments:
- device: The input device. By default, the best available supported device is selected.
Returns:
The device.
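A short sketch of the different ways to call it:

from micro_sam.util import get_device

device = get_device()       # passing 'None' / 'auto' picks cuda, mps or cpu automatically
device = get_device("cpu")  # explicitly run on the CPU
# get_device("cuda") raises a RuntimeError if the CUDA backend is not available.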
318def get_sam_model( 319 model_type: str = _DEFAULT_MODEL, 320 device: Optional[Union[str, torch.device]] = None, 321 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 322 return_sam: bool = False, 323 return_state: bool = False, 324 peft_kwargs: Optional[Dict] = None, 325 flexible_load_checkpoint: bool = False, 326 progress_bar_factory: Optional[Callable] = None, 327 **model_kwargs, 328) -> SamPredictor: 329 r"""Get the Segment Anything Predictor. 330 331 This function will download the required model or load it from the cached weight file. 332 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 333 The name of the requested model can be set via `model_type`. 334 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 335 for an overview of the available models 336 337 Alternatively this function can also load a model from weights stored in a local filepath. 338 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 339 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 340 a SAM model with vit_b encoder. 341 342 By default the models are downloaded to a folder named 'micro_sam/models' 343 inside your default cache directory, eg: 344 * Mac: ~/Library/Caches/<AppName> 345 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 346 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 347 See the pooch.os_cache() documentation for more details: 348 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 349 350 Args: 351 model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. 352 To get a list of all available model names you can call `micro_sam.util.get_model_names`. 353 device: The device for the model. If 'None' is provided, will use GPU if available. 354 checkpoint_path: The path to a file with weights that should be used instead of using the 355 weights corresponding to `model_type`. If given, `model_type` must match the architecture 356 corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder 357 then `model_type` must be given as 'vit_b'. 358 return_sam: Return the sam model object as well as the predictor. By default, set to 'False'. 359 return_state: Return the unpickled checkpoint state. By default, set to 'False'. 360 peft_kwargs: Keyword arguments for th PEFT wrapper class. 361 If passed 'None', it does not initialize any parameter efficient finetuning. 362 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 363 By default, set to 'False'. 364 progress_bar_factory: A function to create a progress bar for the model download. 365 model_kwargs: Additional parameters necessary to initialize the Segment Anything model. 366 367 Returns: 368 The Segment Anything predictor. 369 """ 370 device = get_device(device) 371 372 # We support passing a local filepath to a checkpoint. 373 # In this case we do not download any weights but just use the local weight file, 374 # as it is, without copying it over anywhere or checking it's hashes. 375 376 # checkpoint_path has not been passed, we download a known model and derive the correct 377 # URL from the model_type. If the model_type is invalid pooch will raise an error. 
378 _provided_checkpoint_path = checkpoint_path is not None 379 if checkpoint_path is None: 380 checkpoint_path, model_hash, decoder_path = _download_sam_model(model_type, progress_bar_factory) 381 382 # checkpoint_path has been passed, we use it instead of downloading a model. 383 else: 384 # Check if the file exists and raise an error otherwise. 385 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 386 # (If it isn't the model creation will fail below.) 387 if not os.path.exists(checkpoint_path): 388 raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.") 389 model_hash = _compute_hash(checkpoint_path) 390 decoder_path = None 391 392 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 393 # before calling sam_model_registry. 394 abbreviated_model_type = model_type[:5] 395 if abbreviated_model_type not in _MODEL_TYPES: 396 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 397 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 398 raise RuntimeError( 399 "'mobile_sam' is required for the vit-tiny. " 400 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 401 ) 402 403 state, model_state = _load_checkpoint(checkpoint_path) 404 405 if _provided_checkpoint_path: 406 # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type' 407 # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination. 408 from micro_sam.models.build_sam import _validate_model_type 409 _provided_model_type = _validate_model_type(model_state) 410 411 # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type' 412 # Otherwise replace 'abbreviated_model_type' with the later. 413 if abbreviated_model_type != _provided_model_type: 414 # Printing the message below to avoid any filtering of warnings on user's end. 415 print( 416 f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', " 417 f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. " 418 f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. " 419 "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future." 420 ) 421 422 # Replace the extracted 'abbreviated_model_type' subjected to the model weights. 423 abbreviated_model_type = _provided_model_type 424 425 # Whether to update parameters necessary to initialize the model 426 if model_kwargs: # Checks whether model_kwargs have been provided or not 427 if abbreviated_model_type == "vit_t": 428 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 429 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 430 431 else: 432 sam = sam_model_registry[abbreviated_model_type]() 433 434 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 435 # Overwrites the SAM model by freezing the backbone and allow PEFT. 436 if peft_kwargs and isinstance(peft_kwargs, dict): 437 # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference. 
438 peft_kwargs.pop("quantize", None) 439 440 if abbreviated_model_type == "vit_t": 441 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 442 443 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 444 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 445 if flexible_load_checkpoint: 446 sam = _handle_checkpoint_loading(sam, model_state) 447 else: 448 sam.load_state_dict(model_state) 449 sam.to(device=device) 450 451 predictor = SamPredictor(sam) 452 predictor.model_type = abbreviated_model_type 453 predictor._hash = model_hash 454 predictor.model_name = model_type 455 predictor.checkpoint_path = checkpoint_path 456 457 # Add the decoder to the state if we have one and if the state is returned. 458 if decoder_path is not None and return_state: 459 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 460 461 if return_sam and return_state: 462 return predictor, sam, state 463 if return_sam: 464 return predictor, sam 465 if return_state: 466 return predictor, state 467 return predictor
Get the Segment Anything Predictor.
This function will download the required model or load it from the cached weight file.
The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
The name of the requested model can be set via model_type.
See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
for an overview of the available models
Alternatively this function can also load a model from weights stored in a local filepath.
The corresponding file path is given via checkpoint_path. In this case model_type
must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
a SAM model with vit_b encoder.
By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g.:
- Mac: ~/Library/Caches/<AppName>
- Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
- Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
Arguments:
- model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. To get a list of all available model names you can call micro_sam.util.get_model_names.
- device: The device for the model. If 'None' is provided, will use GPU if available.
- checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to model_type. If given, model_type must match the architecture corresponding to the weight file, e.g. if you use weights for SAM with a vit_b encoder then model_type must be given as 'vit_b'.
- return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
- return_state: Return the unpickled checkpoint state. By default, set to 'False'.
- peft_kwargs: Keyword arguments for the PEFT wrapper class. If passed 'None', it does not initialize any parameter efficient finetuning.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. By default, set to 'False'.
- progress_bar_factory: A function to create a progress bar for the model download.
- model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
Returns:
The Segment Anything predictor.
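A minimal usage sketch; 'finetuned_vit_b.pt' is a hypothetical local weight file:

from micro_sam.util import get_sam_model

# Use the default model; it is downloaded to the cache directory on first use.
predictor = get_sam_model(model_type="vit_b_lm")

# Load custom weights from disk; model_type must name the matching encoder architecture.
predictor = get_sam_model(model_type="vit_b", checkpoint_path="finetuned_vit_b.pt")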
503def export_custom_sam_model( 504 checkpoint_path: Union[str, os.PathLike], 505 model_type: str, 506 save_path: Union[str, os.PathLike], 507 with_segmentation_decoder: bool = False, 508 prefix: str = "sam.", 509) -> None: 510 """Export a finetuned Segment Anything Model to the standard model format. 511 512 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 513 514 Args: 515 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 516 model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 517 save_path: Where to save the exported model. 518 with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. 519 If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'. 520 prefix: The prefix to remove from the model parameter keys. 521 """ 522 state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path) 523 model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]) 524 525 # Store the 'decoder_state' as well, if desired. 526 if with_segmentation_decoder: 527 if "decoder_state" not in state: 528 raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.") 529 decoder_state = state["decoder_state"] 530 save_state = {"model_state": model_state, "decoder_state": decoder_state} 531 else: 532 save_state = model_state 533 534 torch.save(save_state, save_path)
Export a finetuned Segment Anything Model to the standard model format.
The exported model can be used by the interactive annotation tools in micro_sam.annotator.
Arguments:
- checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
- model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
- save_path: Where to save the exported model.
- with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
- prefix: The prefix to remove from the model parameter keys.
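A usage sketch with hypothetical paths; with_segmentation_decoder=True only works if the checkpoint actually contains a 'decoder_state':

from micro_sam.util import export_custom_sam_model

export_custom_sam_model(
    checkpoint_path="checkpoints/vit_b_custom/best.pt",  # hypothetical finetuning checkpoint
    model_type="vit_b",
    save_path="vit_b_custom_export.pt",
    with_segmentation_decoder=True,
)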
537def export_custom_qlora_model( 538 checkpoint_path: Optional[Union[str, os.PathLike]], 539 finetuned_path: Union[str, os.PathLike], 540 model_type: str, 541 save_path: Union[str, os.PathLike], 542) -> None: 543 """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format. 544 545 The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`. 546 547 Args: 548 checkpoint_path: The path to the base foundation model from which the new model has been finetuned. 549 finetuned_path: The path to the new finetuned model, using QLoRA. 550 model_type: The Segment Anything Model type corresponding to the checkpoint. 551 save_path: Where to save the exported model. 552 """ 553 # Step 1: Get the base SAM model: used to start finetuning from. 554 _, sam = get_sam_model( 555 model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, 556 ) 557 558 # Step 2: Load the QLoRA-style finetuned model. 559 ft_state, ft_model_state = _load_checkpoint(finetuned_path) 560 561 # Step 3: Identify LoRA layers from QLoRA model. 562 # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers. 563 # - then copy the LoRA layers from the QLoRA model to the new state dict 564 updated_model_state = {} 565 566 modified_attn_layers = set() 567 modified_mlp_layers = set() 568 569 for k, v in ft_model_state.items(): 570 if "blocks." in k: 571 layer_id = int(k.split("blocks.")[1].split(".")[0]) 572 if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1: 573 modified_attn_layers.add(layer_id) 574 updated_model_state[k] = v 575 if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1: 576 modified_mlp_layers.add(layer_id) 577 updated_model_state[k] = v 578 579 # Step 4: Next, we get all the remaining parameters from the base SAM model. 580 for k, v in sam.state_dict().items(): 581 if "blocks." in k: 582 layer_id = int(k.split("blocks.")[1].split(".")[0]) 583 if k.find("attn.qkv.") != -1: 584 if layer_id in modified_attn_layers: # We have LoRA in QKV layers, so we need to modify the key 585 k = k.replace("qkv", "qkv.qkv_proj") 586 elif k.find("mlp") != -1 and k.find("image_encoder") != -1: 587 if layer_id in modified_mlp_layers: # We have LoRA in MLP layers, so we need to modify the key 588 k = k.replace("mlp.", "mlp.mlp_layer.") 589 updated_model_state[k] = v 590 591 # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff) 592 ft_state['model_state'] = updated_model_state 593 594 # Step 6: Store the new "state" to "save_path" 595 torch.save(ft_state, save_path)
Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
The exported model can be used with the LoRA backbone by passing the relevant peft_kwargs to get_sam_model.
Arguments:
- checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
- finetuned_path: The path to the new finetuned model, using QLoRA.
- model_type: The Segment Anything Model type corresponding to the checkpoint.
- save_path: Where to save the exported model.
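A usage sketch with hypothetical paths; passing checkpoint_path=None starts from the default downloaded weights for the given model_type:

from micro_sam.util import export_custom_qlora_model

export_custom_qlora_model(
    checkpoint_path=None,                              # use the default vit_b weights as the base
    finetuned_path="checkpoints/qlora_vit_b/best.pt",  # hypothetical QLoRA finetuning checkpoint
    model_type="vit_b",
    save_path="vit_b_qlora_export.pt",
)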
1124def precompute_image_embeddings( 1125 predictor: SamPredictor, 1126 input_: np.ndarray, 1127 save_path: Optional[Union[str, os.PathLike]] = None, 1128 lazy_loading: bool = False, 1129 ndim: Optional[int] = None, 1130 tile_shape: Optional[Tuple[int, int]] = None, 1131 halo: Optional[Tuple[int, int]] = None, 1132 verbose: bool = True, 1133 batch_size: int = 1, 1134 mask: Optional[np.typing.ArrayLike] = None, 1135 pbar_init: Optional[callable] = None, 1136 pbar_update: Optional[callable] = None, 1137) -> ImageEmbeddings: 1138 """Compute the image embeddings (output of the encoder) for the input. 1139 1140 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 1141 1142 Args: 1143 predictor: The Segment Anything predictor. 1144 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 1145 save_path: Path to save the embeddings in a zarr container. 1146 By default, set to 'None', i.e. the computed embeddings will not be stored locally. 1147 lazy_loading: Whether to load all embeddings into memory or return an 1148 object to load them on demand when required. This only has an effect if 'save_path' is given 1149 and if the input is 3 dimensional. By default, set to 'False'. 1150 ndim: The dimensionality of the data. If not given will be deduced from the input data. 1151 By default, set to 'None', i.e. will be computed from the provided `input_`. 1152 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 1153 halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling. 1154 verbose: Whether to be verbose in the computation. By default, set to 'True'. 1155 batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'. 1156 mask: An optional mask to define areas that are ignored in the computation. 1157 The mask will be used within tiled embedding computation and tiles that don't contain any foreground 1158 in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings. 1159 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 1160 Can be used together with pbar_update to handle napari progress bar in other thread. 1161 To enables using this function within a threadworker. 1162 pbar_update: Callback to update an external progress bar. 1163 1164 Returns: 1165 The image embeddings. 1166 """ 1167 ndim = input_.ndim if ndim is None else ndim 1168 1169 # Handle the embedding save_path. 1170 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 1171 if save_path is None: 1172 f = zarr.group() 1173 1174 # We have a save path and it already exists. Embeddings will be loaded from it, 1175 # check that the saved embeddings in there match the parameters of the function call. 1176 elif os.path.exists(save_path): 1177 f = zarr.open(save_path, mode="a") 1178 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 1179 1180 # We have a save path and it does not exist yet. Create the zarr file to which the 1181 # embeddings will then be saved. 
1182 else: 1183 f = zarr.open(save_path, mode="a") 1184 1185 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 1186 1187 if ndim == 2 and tile_shape is None: 1188 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 1189 elif ndim == 2 and tile_shape is not None: 1190 embeddings = _compute_tiled_2d( 1191 input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask 1192 ) 1193 elif ndim == 3 and tile_shape is None: 1194 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size) 1195 elif ndim == 3 and tile_shape is not None: 1196 embeddings = _compute_tiled_3d( 1197 input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask 1198 ) 1199 else: 1200 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 1201 1202 pbar_close() 1203 return embeddings
Compute the image embeddings (output of the encoder) for the input.
If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
Arguments:
- predictor: The Segment Anything predictor.
- input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
- save_path: Path to save the embeddings in a zarr container. By default, set to 'None', i.e. the computed embeddings will not be stored locally.
- lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional. By default, set to 'False'.
- ndim: The dimensionality of the data. If not given, it will be deduced from the input data. By default, set to 'None', i.e. it will be computed from the provided input_.
- tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Whether to be verbose in the computation. By default, set to 'True'.
- batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
- mask: An optional mask to define areas that are ignored in the computation. The mask will be used within tiled embedding computation and tiles that don't contain any foreground in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
Returns:
The image embeddings.
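A minimal sketch for a single 2D image; the file names are illustrative placeholders, and tile_shape / halo are only needed for large images:

import imageio.v3 as imageio
from micro_sam.util import get_sam_model, precompute_image_embeddings

predictor = get_sam_model(model_type="vit_b_lm")
image = imageio.imread("example_image.tif")  # hypothetical 2D image

# Embeddings are stored in 'embeddings.zarr' and re-used on subsequent calls with the same parameters.
embeddings = precompute_image_embeddings(
    predictor, image, save_path="embeddings.zarr",
    tile_shape=(1024, 1024), halo=(256, 256),
)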
1206def set_precomputed( 1207 predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None, 1208) -> SamPredictor: 1209 """Set the precomputed image embeddings for a predictor. 1210 1211 Args: 1212 predictor: The Segment Anything predictor. 1213 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 1214 i: Index for the image data. Required if `image` has three spatial dimensions 1215 or a time dimension and two spatial dimensions. 1216 tile_id: Index for the tile. This is required if the embeddings are tiled. 1217 1218 Returns: 1219 The predictor with set features. 1220 """ 1221 if tile_id is not None: 1222 tile_features = image_embeddings["features"][str(tile_id)] 1223 tile_image_embeddings = { 1224 "features": tile_features, 1225 "input_size": tile_features.attrs["input_size"], 1226 "original_size": tile_features.attrs["original_size"] 1227 } 1228 return set_precomputed(predictor, tile_image_embeddings, i=i) 1229 1230 device = predictor.device 1231 features = image_embeddings["features"] 1232 assert features.ndim in (4, 5), f"{features.ndim}" 1233 if features.ndim == 5 and i is None: 1234 raise ValueError("The data is 3D so an index i is needed.") 1235 elif features.ndim == 4 and i is not None: 1236 raise ValueError("The data is 2D so an index is not needed.") 1237 1238 if i is None: 1239 predictor.features = features.to(device) if torch.is_tensor(features) else \ 1240 torch.from_numpy(features[:]).to(device) 1241 else: 1242 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 1243 torch.from_numpy(features[i]).to(device) 1244 1245 predictor.original_size = image_embeddings["original_size"] 1246 predictor.input_size = image_embeddings["input_size"] 1247 predictor.is_image_set = True 1248 1249 return predictor
Set the precomputed image embeddings for a predictor.
Arguments:
- predictor: The Segment Anything predictor.
- image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
- i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:
The predictor with set features.
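A sketch for a small 3D volume, where the plane index has to be passed; the random array here only stands in for real image data:

import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

predictor = get_sam_model(model_type="vit_b_lm")
volume = np.random.rand(4, 512, 512).astype("float32")  # stand-in for a real volume

embeddings = precompute_image_embeddings(predictor, volume, ndim=3)
predictor = set_precomputed(predictor, embeddings, i=0)  # select the first plane
# The predictor can now be prompted, e.g. with a point, without recomputing the encoder pass.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[256, 256]]), point_labels=np.array([1])
)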
1257def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 1258 """Compute the intersection over union of two masks. 1259 1260 Args: 1261 mask1: The first mask. 1262 mask2: The second mask. 1263 1264 Returns: 1265 The intersection over union of the two masks. 1266 """ 1267 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 1268 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 1269 eps = 1e-7 1270 iou = float(overlap) / (float(union) + eps) 1271 return iou
Compute the intersection over union of two masks.
Arguments:
- mask1: The first mask.
- mask2: The second mask.
Returns:
The intersection over union of the two masks.
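A quick check with two overlapping rectangles:

import numpy as np
from micro_sam.util import compute_iou

mask1 = np.zeros((64, 64), dtype="uint8")
mask1[10:30, 10:30] = 1
mask2 = np.zeros((64, 64), dtype="uint8")
mask2[20:40, 20:40] = 1

print(compute_iou(mask1, mask2))  # ~0.143: 100 px overlap / 700 px union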
1274def get_centers_and_bounding_boxes( 1275 segmentation: np.ndarray, mode: str = "v" 1276) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 1277 """Returns the center coordinates of the foreground instances in the ground-truth. 1278 1279 Args: 1280 segmentation: The segmentation. 1281 mode: Determines the functionality used for computing the centers. 1282 If 'v', the object's eccentricity centers computed by vigra are used. 1283 If 'p' the object's centroids computed by skimage are used. 1284 1285 Returns: 1286 A dictionary that maps object ids to the corresponding centroid. 1287 A dictionary that maps object_ids to the corresponding bounding box. 1288 """ 1289 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 1290 1291 properties = regionprops(segmentation) 1292 1293 if mode == "p": 1294 center_coordinates = {prop.label: prop.centroid for prop in properties} 1295 elif mode == "v": 1296 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 1297 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 1298 1299 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 1300 1301 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 1302 return center_coordinates, bbox_coordinates
Returns the center coordinates of the foreground instances in the ground-truth.
Arguments:
- segmentation: The segmentation.
- mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p', the object's centroids computed by skimage are used.
Returns:
A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.
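A small sketch using the regionprops-based mode:

import numpy as np
from micro_sam.util import get_centers_and_bounding_boxes

segmentation = np.zeros((64, 64), dtype="uint32")
segmentation[5:15, 5:15] = 1
segmentation[30:50, 30:40] = 2

centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="p")
print(centers[1])  # (9.5, 9.5), the centroid of object 1
print(boxes[1])    # (5, 5, 15, 15), its bounding box as (min_row, min_col, max_row, max_col)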
1305def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray: 1306 """Helper function to load image data from file. 1307 1308 Args: 1309 path: The filepath to the image data. 1310 key: The internal filepath for complex data formats like hdf5. 1311 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 1312 1313 Returns: 1314 The image data. 1315 """ 1316 if key is None: 1317 image_data = imageio.imread(path) 1318 else: 1319 with open_file(path, mode="r") as f: 1320 image_data = f[key] 1321 if not lazy_loading: 1322 image_data = image_data[:] 1323 1324 return image_data
Helper function to load image data from file.
Arguments:
- path: The filepath to the image data.
- key: The internal filepath for complex data formats like hdf5.
- lazy_loading: Whether to lazily load the data. Only supported for n5 and zarr data.
Returns:
The image data.
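A short sketch; the file names and the internal key 'raw' are hypothetical:

from micro_sam.util import load_image_data

# Regular image formats are read with imageio.
image = load_image_data("example_image.tif")

# For hdf5 / n5 / zarr containers, pass the internal dataset key.
# With lazy_loading=True the data stays on disk and is only read on access.
volume = load_image_data("example_data.zarr", key="raw", lazy_loading=True)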
1327def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor: 1328 """Convert the segmentation to one-hot encoded masks. 1329 1330 Args: 1331 segmentation: The segmentation. 1332 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1333 By default, computes the number of ids from the provided `segmentation` masks. 1334 1335 Returns: 1336 The one-hot encoded masks. 1337 """ 1338 masks = segmentation.copy() 1339 if segmentation_ids is None: 1340 n_ids = int(segmentation.max()) 1341 1342 else: 1343 msg = "No foreground objects were found." 1344 if len(segmentation_ids) == 0: # The list should not be completely empty. 1345 raise RuntimeError(msg) 1346 1347 if 0 in segmentation_ids: # The list should not have 'zero' as a value. 1348 raise RuntimeError(msg) 1349 1350 # the segmentation ids have to be sorted 1351 segmentation_ids = np.sort(segmentation_ids) 1352 1353 # set the non selected objects to zero and relabel sequentially 1354 masks[~np.isin(masks, segmentation_ids)] = 0 1355 masks = relabel_sequential(masks)[0] 1356 n_ids = len(segmentation_ids) 1357 1358 masks = torch.from_numpy(masks) 1359 1360 one_hot_shape = (n_ids + 1,) + masks.shape 1361 masks = masks.unsqueeze(0) # add dimension to scatter 1362 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1363 1364 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1365 masks = masks.unsqueeze(1) 1366 return masks
Convert the segmentation to one-hot encoded masks.
Arguments:
- segmentation: The segmentation.
- segmentation_ids: Optional subset of ids that will be used to subsample the masks. By default, computes the number of ids from the provided segmentation masks.
Returns:
The one-hot encoded masks.
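A small sketch; note that the ids passed via segmentation_ids must be foreground (non-zero) ids:

import numpy as np
from micro_sam.util import segmentation_to_one_hot

segmentation = np.zeros((32, 32), dtype="int64")
segmentation[2:10, 2:10] = 1
segmentation[15:25, 15:25] = 3

one_hot = segmentation_to_one_hot(segmentation, segmentation_ids=np.array([1, 3]))
print(one_hot.shape)  # torch.Size([2, 1, 32, 32]): one binary channel per selected object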
1369def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1370 """Get a suitable block shape for chunking a given shape. 1371 1372 The primary use for this is determining chunk sizes for 1373 zarr arrays or block shapes for parallelization. 1374 1375 Args: 1376 shape: The image or volume shape. 1377 1378 Returns: 1379 The block shape. 1380 """ 1381 ndim = len(shape) 1382 if ndim == 2: 1383 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1384 elif ndim == 3: 1385 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1386 else: 1387 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1388 1389 return block_shape
Get a suitable block shape for chunking a given shape.
The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.
Arguments:
- shape: The image or volume shape.
Returns:
The block shape.
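A quick illustration of the defaults:

from micro_sam.util import get_block_shape

print(get_block_shape((512, 512)))         # (512, 512), clipped to the image shape
print(get_block_shape((2048, 2048)))       # (1024, 1024), the 2d default
print(get_block_shape((100, 1024, 1024)))  # (32, 256, 256), the 3d default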
1392def micro_sam_info() -> None: 1393 """Display μSAM information using a rich console.""" 1394 import psutil 1395 import platform 1396 import argparse 1397 from rich import progress 1398 from rich.panel import Panel 1399 from rich.table import Table 1400 from rich.console import Console 1401 1402 import torch 1403 import micro_sam 1404 1405 parser = argparse.ArgumentParser(description="μSAM Information Booth") 1406 parser.add_argument( 1407 "--download", nargs="+", metavar=("WHAT", "KIND"), 1408 help="Downloads the pretrained SAM models." 1409 "'--download models' -> downloads all pretrained models; " 1410 "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; " 1411 "'--download model/models vit_b_lm' -> downloads a single specified model." 1412 ) 1413 args = parser.parse_args() 1414 1415 # Open up a new console. 1416 console = Console() 1417 1418 # The header for information CLI. 1419 console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center") 1420 console.print("-" * console.width) 1421 1422 # μSAM version panel. 1423 console.print( 1424 Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True) 1425 ) 1426 1427 # The documentation link panel. 1428 console.print( 1429 Panel( 1430 "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n" 1431 "https://computational-cell-analytics.github.io/micro-sam", title="Documentation" 1432 ) 1433 ) 1434 1435 # The publication panel. 1436 console.print( 1437 Panel( 1438 "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n" 1439 "https://www.nature.com/articles/s41592-024-02580-4", title="Publication" 1440 ) 1441 ) 1442 1443 # Creating a cache directory when users' run `micro_sam.info`. 1444 cache_dir = get_cache_directory() 1445 os.makedirs(cache_dir, exist_ok=True) 1446 1447 # The cache directory panel. 1448 console.print( 1449 Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{cache_dir}", title="Cache Directory") 1450 ) 1451 1452 # We have a simple versioning logic here (which is what I'll follow here for mapping model versions). 1453 available_models = [] 1454 for model_name, model_path in models().urls.items(): # We filter out the decoder models. 1455 if model_name.endswith("decoder"): 1456 continue 1457 1458 if "https://dl.fbaipublicfiles.com/segment_anything/" in model_path: # Valid v1 SAM models. 1459 available_models.append(model_name) 1460 1461 if "https://owncloud.gwdg.de/" in model_path: # Our own hosted models (in their v1 mode quite often) 1462 if model_name == "vit_t": # MobileSAM model. 1463 available_models.append(model_name) 1464 else: 1465 available_models.append(f"{model_name} (v1)") 1466 1467 # Now for our models, the BioImageIO ModelZoo upload structure is such that: 1468 # '/1/files' corresponds to v2 models. 1469 # '/1.1/files' corresponds to v3 models. 1470 # '/1.2/files' corresponds to v4 models. 1471 if "/1/files" in model_path: 1472 available_models.append(f"{model_name} (v2)") 1473 if "/1.1/files" in model_path: 1474 available_models.append(f"{model_name} (v3)") 1475 if "/1.2/files" in model_path: 1476 available_models.append(f"{model_name} (v4)") 1477 1478 model_list = "\n".join(available_models) 1479 1480 # The available models panel. 1481 console.print( 1482 Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models") 1483 ) 1484 1485 # The system information table. 
1486 total_memory = psutil.virtual_memory().total / (1024 ** 3) 1487 table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True) 1488 table.add_column("Property") 1489 table.add_column("Value", style="bold #56B4E9") 1490 table.add_row("System", platform.system()) 1491 table.add_row("Node Name", platform.node()) 1492 table.add_row("Release", platform.release()) 1493 table.add_row("Version", platform.version()) 1494 table.add_row("Machine", platform.machine()) 1495 table.add_row("Processor", platform.processor()) 1496 table.add_row("Platform", platform.platform()) 1497 table.add_row("Total RAM (GB)", f"{total_memory:.2f}") 1498 console.print(table) 1499 1500 # The device information and check for available GPU acceleration. 1501 default_device = _get_default_device() 1502 1503 if default_device == "cuda": 1504 device_index = torch.cuda.current_device() 1505 device_name = torch.cuda.get_device_name(device_index) 1506 console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information")) 1507 elif default_device == "mps": 1508 console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information")) 1509 else: 1510 console.print( 1511 Panel( 1512 "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]", 1513 title="Device Information" 1514 ) 1515 ) 1516 1517 # The section allowing to download models. 1518 # NOTE: In future, can be extended to download sample data. 1519 if args.download: 1520 download_provided_args = [t.lower() for t in args.download] 1521 mode, *model_types = download_provided_args 1522 1523 if mode not in {"models", "model"}: 1524 console.print(f"[red]Unknown option for --download: {mode}[/]") 1525 return 1526 1527 if mode in ["model", "models"] and not model_types: # If user did not specify, we will download all models. 1528 download_list = available_models 1529 else: 1530 download_list = model_types 1531 incorrect_models = [m for m in download_list if m not in available_models] 1532 if incorrect_models: 1533 console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error")) 1534 return 1535 1536 with progress.Progress( 1537 progress.SpinnerColumn(), 1538 progress.TextColumn("[progress.description]{task.description}"), 1539 progress.BarColumn(bar_width=None), 1540 "[progress.percentage]{task.percentage:>3.0f}%", 1541 progress.TimeRemainingColumn(), 1542 console=console, 1543 ) as prog: 1544 task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list)) 1545 for model_type in download_list: 1546 prog.update(task, description=f"Downloading [cyan]{model_type}[/]…") 1547 _download_sam_model(model_type=model_type) 1548 prog.advance(task) 1549 1550 console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))
Display μSAM information using a rich console.
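A minimal usage sketch: the function parses its arguments from sys.argv via argparse, so it can also be driven programmatically by setting sys.argv before the call. The flag values in the comment simply mirror the --download help text shown in the source; which model names are accepted depends on the model registry of the installed version.

import sys

from micro_sam.util import micro_sam_info

# Equivalent to running the `micro_sam.info` CLI without flags:
# prints the version, documentation, publication, cache-directory and system panels.
sys.argv = ["micro_sam.info"]
micro_sam_info()

# Per the help text above, model downloads can be requested on the command line, e.g.:
#   micro_sam.info --download models vit_b_lm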
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    shape: Optional[Tuple[int, int]] = None,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
    with_background: bool = False,
    merge_exclusively: bool = True,
) -> np.ndarray:
    """Convert the output of the automatic mask generation to an instance segmentation.

    Args:
        masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`,
            or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask.
        shape: The shape of the output segmentation. If None, it will be derived from the mask input.
            If the masks were predicted with tiling then the shape must be given.
        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.
            By default, set to 'True'.
        with_background: Whether to remove the largest object, which often covers the background for AMG.
        merge_exclusively: Whether to exclude previously merged masks from merging.

    Returns:
        The instance segmentation.
    """
    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    if shape is None:
        shape = next(iter(masks))["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    def require_numpy(mask):
        return mask.cpu().numpy() if torch.is_tensor(mask) else mask

    seg_id = 1
    for mask_data in masks:
        area = mask_data["area"]
        if (area < min_object_size) or (max_object_size is not None and area > max_object_size):
            continue

        this_mask = require_numpy(mask_data["segmentation"])
        this_seg_id = mask_data.get("seg_id", seg_id)
        if "global_bbox" in mask_data:
            bb = mask_data["bbox"]
            bb = np.s_[bb[1]:bb[1] + bb[3], bb[0]:bb[0] + bb[2]]
            global_bb = mask_data["global_bbox"]
            global_bb = np.s_[global_bb[1]:global_bb[1] + global_bb[3], global_bb[0]:global_bb[0] + global_bb[2]]
            if merge_exclusively:
                this_mask = np.logical_and(this_mask[bb], segmentation[global_bb] == 0)
            else:
                this_mask = this_mask[bb]
            segmentation[global_bb][this_mask] = this_seg_id
        else:
            if merge_exclusively:
                this_mask = np.logical_and(this_mask, segmentation == 0)
            segmentation[this_mask] = this_seg_id
        seg_id = this_seg_id + 1

    block_shape = (512, 512)
    if label_masks:
        segmentation_cc = np.zeros_like(segmentation, dtype=segmentation.dtype)
        segmentation_cc = parallel_impl.label(segmentation, out=segmentation_cc, block_shape=block_shape)
        segmentation = segmentation_cc

    seg_ids, sizes = parallel_impl.unique(segmentation, return_counts=True, block_shape=block_shape)
    filter_ids = seg_ids[sizes < min_object_size]
    if with_background:
        bg_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [bg_id]])

    filter_mask = np.zeros(segmentation.shape, dtype="bool")
    filter_mask = parallel_impl.isin(segmentation, filter_ids, out=filter_mask, block_shape=block_shape)
    segmentation[filter_mask] = 0
    parallel_impl.relabel_consecutive(segmentation, block_shape=block_shape)[0]

    return segmentation
Convert the output of the automatic mask generation to an instance segmentation.
Arguments:
- masks: The outputs generated by AutomaticMaskGenerator, other classes from micro_sam.instance_segmentation, or from micro_sam.inference functions. Only supported for output_mode=binary_mask.
- shape: The shape of the output segmentation. If None, it will be derived from the mask input. If the masks were predicted with tiling then the shape must be given.
- min_object_size: The minimal size of an object in pixels. By default, set to '0'.
- max_object_size: The maximal size of an object in pixels.
- label_masks: Whether to apply connected components to the result before removing small objects. By default, set to 'True'.
- with_background: Whether to remove the largest object, which often covers the background for AMG.
- merge_exclusively: Whether to exclude previously merged masks from merging.
Returns:
The instance segmentation.
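A small self-contained sketch of the conversion: the two mask records below are synthetic stand-ins for the dictionaries produced by the automatic mask generation in binary_mask output mode (real inputs would come from AutomaticMaskGenerator or the micro_sam.inference functions). Only the keys this function reads for untiled inputs ('segmentation' and 'area') are filled in.

import numpy as np

from micro_sam.util import mask_data_to_segmentation

shape = (64, 64)
mask_a = np.zeros(shape, dtype=bool)
mask_a[5:20, 5:20] = True
mask_b = np.zeros(shape, dtype=bool)
mask_b[30:60, 30:60] = True

masks = [
    {"segmentation": mask_a, "area": int(mask_a.sum())},
    {"segmentation": mask_b, "area": int(mask_b.sum())},
]

# Objects smaller than 50 pixels are discarded; the output shape is derived from the masks.
segmentation = mask_data_to_segmentation(masks, min_object_size=50)
print(segmentation.shape, np.unique(segmentation))  # expected: (64, 64) with labels 0, 1, 2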
def apply_nms(
    predictions: List[Dict[str, Any]],
    min_size: int,
    shape: Optional[Tuple[int, int]] = None,
    perform_box_nms: bool = False,
    nms_thresh: float = 0.9,
    max_size: Optional[int] = None,
    intersection_over_min: bool = False,
) -> np.ndarray:
    """Apply non-maximum suppression to mask predictions from a segment anything model.

    Args:
        predictions: The mask predictions from SAM.
        min_size: The minimum mask size to keep in the output.
        shape: The shape of the output segmentation.
            Has to be passed for predictions obtained from tiling.
        perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
        nms_thresh: The threshold for filtering out objects in NMS.
        max_size: The maximum mask size to keep in the output.
        intersection_over_min: Whether to perform intersection over the minimum overlap shape
            or to perform intersection over union.

    Returns:
        The segmentation obtained from merging the masks left after NMS.
    """
    data = amg_utils.MaskData(
        masks=torch.cat([pred["segmentation"][None] for pred in predictions], dim=0),
        iou_preds=torch.tensor([pred["predicted_iou"] for pred in predictions]),
    )
    data["boxes"] = torch.tensor(np.array([pred["bbox"] for pred in predictions]))
    data["area"] = [mask.sum() for mask in data["masks"]]
    data["stability_scores"] = torch.tensor([pred["stability_score"] for pred in predictions])

    # Check if the input comes with a 'global_bbox' attribute. If it does, then the predictions are from
    # a tiled prediction. In this case, we have to take the coordinates w.r.t. the tiling into account.
    if "global_bbox" in predictions[0]:
        if shape is None:
            raise ValueError("The output shape 'shape' has to be passed for tiled predictions.")
        data["global_boxes"] = torch.tensor(np.array([pred["global_bbox"] for pred in predictions]))
        is_tiled = True
    else:
        is_tiled = False

    if min_size > 0:
        keep_by_size = torch.tensor(
            [i for i, area in enumerate(data["area"]) if area > min_size], dtype=torch.long,
        )
        data.filter(keep_by_size)

    if max_size is not None:
        keep_by_size = torch.tensor([i for i, area in enumerate(data["area"]) if area < max_size])
        data.filter(keep_by_size)

    scores = data["iou_preds"] * data["stability_scores"]
    if perform_box_nms:
        assert not intersection_over_min  # not implemented
        keep_by_nms = batched_nms(
            data["global_boxes"].float() if is_tiled else data["boxes"].float(),
            scores,
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=nms_thresh,
        )
    else:
        keep_by_nms = _batched_mask_nms(
            masks=data["masks"],
            boxes=data["global_boxes"].float() if is_tiled else data["boxes"].float(),
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    data.filter(keep_by_nms)

    if is_tiled:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box, "global_bbox": global_box}
            for mask, area, box, global_box in zip(data["masks"], data["area"], data["boxes"], data["global_boxes"])
        ]
    else:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box}
            for mask, area, box in zip(data["masks"], data["area"], data["boxes"])
        ]

    if shape is None:
        shape = predictions[0]["segmentation"].shape
    if mask_data:
        segmentation = mask_data_to_segmentation(mask_data, shape=shape, min_object_size=min_size)
    else:  # In case all objects have been filtered out due to size filtering.
        segmentation = np.zeros(shape, dtype="uint32")

    return segmentation
Apply non-maximum suppression to mask predictions from a segment anything model.
Arguments:
- predictions: The mask predictions from SAM.
- min_size: The minimum mask size to keep in the output.
- shape: The shape of the output segmentation. Has to be passed for predictions obtained from tiling.
- perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
- nms_thresh: The threshold for filtering out objects in NMS.
- max_size: The maximum mask size to keep in the output.
- intersection_over_min: Whether to perform intersection over the minimum overlap shape or to perform intersection over union.
Returns:
The segmentation obtained from merging the masks left after NMS.
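For illustration, a self-contained sketch with synthetic predictions (not actual SAM output): three torch masks, two of which overlap heavily, are passed through the box-based NMS path (perform_box_nms=True), and the surviving masks are merged into an instance segmentation. The corner-style box coordinates and the score values here are illustrative assumptions; real inputs carry whatever the upstream micro_sam.inference functions produce.

import numpy as np
import torch

from micro_sam.util import apply_nms


def make_pred(y0, y1, x0, x1, score, shape=(64, 64)):
    # Build one prediction record with the keys apply_nms expects.
    mask = torch.zeros(shape, dtype=torch.bool)
    mask[y0:y1, x0:x1] = True
    return {
        "segmentation": mask,
        "bbox": [x0, y0, x1, y1],  # corner format, since this sketch uses the box-NMS path
        "predicted_iou": score,
        "stability_score": score,
    }


predictions = [
    make_pred(5, 20, 5, 20, 0.95),    # object 1
    make_pred(6, 21, 6, 21, 0.80),    # near-duplicate of object 1, lower score
    make_pred(35, 60, 35, 60, 0.90),  # object 2
]

segmentation = apply_nms(predictions, min_size=50, nms_thresh=0.7, perform_box_nms=True)
print(np.unique(segmentation))  # expected: labels 0, 1, 2 (the duplicate is suppressed)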