# micro_sam.util
# Helper functions for downloading Segment Anything models and predicting image embeddings.
"""
Helper functions for downloading Segment Anything models and predicting image embeddings.
"""

import os
import multiprocessing as mp
import pickle
import hashlib
import warnings
from concurrent import futures
from pathlib import Path
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Callable

import elf.parallel as parallel_impl
import imageio.v3 as imageio
import numpy as np
import pooch
import segment_anything.utils.amg as amg_utils
import torch
import xxhash
import zarr

from elf.io import open_file
from bioimage_cpp.utils import Blocking
from bioimage_cpp.distance import distance_transform
from bioimage_cpp.segmentation import relabel_sequential
from skimage.measure import regionprops
from skimage.segmentation import find_boundaries
from torchvision.ops.boxes import batched_nms

from .__version__ import __version__
from . import models as custom_models

try:
    # Avoid import warnings from mobile_sam
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from mobile_sam import sam_model_registry, SamPredictor
    VIT_T_SUPPORT = True
except ImportError:
    from segment_anything import sam_model_registry, SamPredictor
    VIT_T_SUPPORT = False

try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

# This is the default model used in micro_sam
# Currently it is set to vit_b_lm
_DEFAULT_MODEL = "vit_b_lm"

# The valid model types. Each type corresponds to the architecture of the
# vision transformer used within SAM.
_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t")


ImageEmbeddings = Dict[str, Any]
"""@private"""


def get_cache_directory() -> Path:
    """Get micro-sam cache directory location.

    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
    An unset or empty variable falls back to the default `pooch` cache location for 'micro_sam'.

    Returns:
        The cache directory as a `pathlib.Path`.
    """
    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
    # Use `or` so that an empty MICROSAM_CACHEDIR does not resolve to the current
    # working directory; this matches the fallback used in `microsam_cachedir`.
    return Path(os.environ.get("MICROSAM_CACHEDIR") or default_cache_directory)


#
# Functionality for model download and export
#


def microsam_cachedir() -> Union[str, os.PathLike]:
    """Return the micro-sam cache directory.

    Returns the top level cache directory for micro-sam models and sample data.

    Every time this function is called, we check for any user updates made to
    the MICROSAM_CACHEDIR os environment variable since the last time.

    Returns:
        The value of MICROSAM_CACHEDIR if it is set and non-empty,
        otherwise the result of `pooch.os_cache("micro_sam")`.
    """
    return os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")


def models():
    """Return the segmentation models registry.

    We recreate the model registry every time this function is called,
    so any user changes to the default micro-sam cache directory location
    are respected.

    Returns:
        The `pooch` registry created via `pooch.create`. It stores its files in the
        'models' sub-folder of `microsam_cachedir()`.
    """

    # We use xxhash to compute the hash of the models, see
    # https://github.com/computational-cell-analytics/micro-sam/issues/283
    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
    # To generate the xxh128 hash:
    #     xxh128sum filename
    encoder_registry = {
        # The default segment anything models:
        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
        # The current version of our models in the modelzoo.
        # LM generalist models:
        "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d",
        "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186",
        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
        # EM models:
        "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d",
        "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438",
        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
        # Histopathology models:
        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
        # Medical Imaging models:
        "vit_b_medical_imaging": "xxh128:40169f1e3c03a4b67bff58249c176d92",
    }
    # Additional decoders for instance segmentation.
    decoder_registry = {
        # LM generalist models:
        "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a",
        "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba",
        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
        # EM models:
        "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294",
        "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2",
        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
        # Histopathology models:
        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
        # Medical Imaging models:
        "vit_b_medical_imaging_decoder": "xxh128:9e498b12f526f119b96c88be76e3b2ed",
    }
    registry = {**encoder_registry, **decoder_registry}

    encoder_urls = {
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt",
        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt",
        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt",  # noqa
        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt",  # noqa
        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/f5Ol4FrjPQWfjUF/download",
    }

    decoder_urls = {
        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt",  # noqa
        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt",  # noqa
        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt",  # noqa
        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt",  # noqa
        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
        "vit_b_medical_imaging_decoder": "https://owncloud.gwdg.de/index.php/s/ahd3ZhZl2e0RIwz/download",
    }
    urls = {**encoder_urls, **decoder_urls}

    # NOTE: the local is deliberately not called 'models' to avoid shadowing this function.
    model_registry = pooch.create(
        path=os.path.join(microsam_cachedir(), "models"),
        base_url="",
        registry=registry,
        urls=urls,
    )
    return model_registry


def _get_default_device():
    """Select the default device: 'cpu' on GitHub Actions, else 'cuda', then 'mps', then 'cpu'."""
    # check that we're in CI and use the CPU if we are
    # otherwise the tests may run out of memory on MAC if MPS is used.
    if os.getenv("GITHUB_ACTIONS") == "true":
        return "cpu"
    # Use cuda enabled gpu if it's available.
    if torch.cuda.is_available():
        return "cuda"
    # As second priority use mps.
    # See https://pytorch.org/docs/stable/notes/mps.html for details
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        print("Using apple MPS device.")
        return "mps"
    # Use the CPU as fallback.
    return "cpu"


def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed the default device for your system is used.
    Else it will be checked if the device you have passed is supported.

    Args:
        device: The input device. If `None` or 'auto', the best available device is selected.
            Strings may carry a device index (e.g. 'cuda:0'); only the part before
            the colon is used for the availability check.

    Returns:
        The device. A device that was passed explicitly is returned unchanged.

    Raises:
        RuntimeError: If the requested backend is not available, or if the device type
            is not one of 'cpu', 'cuda' or 'mps'.
    """
    if device is None or device == "auto":
        return _get_default_device()

    raw_type = device if isinstance(device, str) else device.type
    # Drop an optional index suffix so that e.g. "cuda:0" is validated as "cuda".
    device_type = raw_type.split(":")[0].lower()

    if device_type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("PyTorch CUDA backend is not available.")
    elif device_type == "mps":
        if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
            raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
    elif device_type != "cpu":  # cpu is always available
        raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.")

    return device


def _available_devices():
    """Return the subset of ('cuda', 'mps', 'cpu') that `get_device` accepts on this system."""
    available_devices = []
    for candidate in ["cuda", "mps", "cpu"]:
        try:
            device = get_device(candidate)
        except RuntimeError:
            # The backend is not available here; skip it.
            continue
        available_devices.append(device)
    return available_devices


# We write a custom unpickler that skips objects that cannot be found instead of
# throwing an AttributeError or ModuleNotFoundError.
# NOTE: since we just want to unpickle the model to load its weights these errors don't matter.
# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules
class _CustomUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        """Resolve `module.name`, returning `None` (with a warning) if it cannot be found."""
        try:
            return super().find_class(module, name)
        except (AttributeError, ModuleNotFoundError) as e:
            warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}")
            return None


def _compute_hash(path, chunk_size=8192):
    """Compute the xxh128 hash of the file at `path`, reading it in chunks of `chunk_size` bytes.

    Returns:
        The hex digest prefixed with 'xxh128:'.
    """
    hash_obj = xxhash.xxh128()
    with open(path, "rb") as f:
        # Read until `f.read` returns an empty bytes object, i.e. the end of the file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hash_obj.update(chunk)
    return f"xxh128:{hash_obj.hexdigest()}"


# Load the state from a checkpoint.
271# The checkpoint can either contain a sam encoder state 272# or it can be a checkpoint for model finetuning. 273def _load_checkpoint(checkpoint_path): 274 # Over-ride the unpickler with our custom one. 275 # This enables imports from torch_em checkpoints even if it cannot be fully unpickled. 276 custom_pickle = pickle 277 custom_pickle.Unpickler = _CustomUnpickler 278 279 state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle) 280 if "model_state" in state: 281 # Copy the model weights from torch_em's training format. 282 model_state = state["model_state"] 283 sam_prefix = "sam." 284 model_state = OrderedDict( 285 [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] 286 ) 287 else: 288 model_state = state 289 290 return state, model_state 291 292 293def _download_sam_model(model_type, progress_bar_factory=None): 294 model_registry = models() 295 296 progress_bar = True 297 # Check if we have to download the model. 298 # If we do and have a progress bar factory, then we over-write the progress bar. 299 if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None: 300 progress_bar = progress_bar_factory(model_type) 301 302 checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar) 303 if not isinstance(progress_bar, bool): # Close the progress bar when the task finishes. 304 progress_bar.close() 305 306 model_hash = model_registry.registry[model_type] 307 308 # If we have a custom model then we may also have a decoder checkpoint. 309 # Download it here, so that we can add it to the state. 
310 decoder_name = f"{model_type}_decoder" 311 decoder_path = model_registry.fetch( 312 decoder_name, progressbar=True 313 ) if decoder_name in model_registry.registry else None 314 315 return checkpoint_path, model_hash, decoder_path 316 317 318def get_sam_model( 319 model_type: str = _DEFAULT_MODEL, 320 device: Optional[Union[str, torch.device]] = None, 321 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 322 return_sam: bool = False, 323 return_state: bool = False, 324 peft_kwargs: Optional[Dict] = None, 325 flexible_load_checkpoint: bool = False, 326 progress_bar_factory: Optional[Callable] = None, 327 decoder_path: Optional[Union[str, os.PathLike]] = None, 328 **model_kwargs, 329) -> SamPredictor: 330 r"""Get the Segment Anything Predictor. 331 332 This function will download the required model or load it from the cached weight file. 333 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 334 The name of the requested model can be set via `model_type`. 335 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 336 for an overview of the available models 337 338 Alternatively this function can also load a model from weights stored in a local filepath. 339 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 340 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 341 a SAM model with vit_b encoder. 342 343 By default the models are downloaded to a folder named 'micro_sam/models' 344 inside your default cache directory, eg: 345 * Mac: ~/Library/Caches/<AppName> 346 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 
347 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 348 See the pooch.os_cache() documentation for more details: 349 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 350 351 Args: 352 model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. 353 To get a list of all available model names you can call `micro_sam.util.get_model_names`. 354 device: The device for the model. If 'None' is provided, will use GPU if available. 355 checkpoint_path: The path to a file with weights that should be used instead of using the 356 weights corresponding to `model_type`. If given, `model_type` must match the architecture 357 corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder 358 then `model_type` must be given as 'vit_b'. 359 return_sam: Return the sam model object as well as the predictor. By default, set to 'False'. 360 return_state: Return the unpickled checkpoint state. By default, set to 'False'. 361 peft_kwargs: Keyword arguments for th PEFT wrapper class. 362 If passed 'None', it does not initialize any parameter efficient finetuning. 363 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 364 By default, set to 'False'. 365 progress_bar_factory: A function to create a progress bar for the model download. 366 decoder_path: Optional path to weights for a segmentation decoder. If given and 367 `return_state=True`, the decoder state is added to the returned state as 368 'decoder_state'. This can be used to provide decoder-only weights that are 369 separate from the encoder checkpoint. 370 model_kwargs: Additional parameters necessary to initialize the Segment Anything model. 371 372 Returns: 373 The Segment Anything predictor. 374 """ 375 device = get_device(device) 376 377 # We support passing a local filepath to a checkpoint. 
378 # In this case we do not download any weights but just use the local weight file, 379 # as it is, without copying it over anywhere or checking it's hashes. 380 381 # checkpoint_path has not been passed, we download a known model and derive the correct 382 # URL from the model_type. If the model_type is invalid pooch will raise an error. 383 _provided_checkpoint_path = checkpoint_path is not None 384 if checkpoint_path is None: 385 checkpoint_path, model_hash, downloaded_decoder_path = _download_sam_model(model_type, progress_bar_factory) 386 if decoder_path is None: 387 decoder_path = downloaded_decoder_path 388 389 # checkpoint_path has been passed, we use it instead of downloading a model. 390 else: 391 # Check if the file exists and raise an error otherwise. 392 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 393 # (If it isn't the model creation will fail below.) 394 if not os.path.exists(checkpoint_path): 395 raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.") 396 model_hash = _compute_hash(checkpoint_path) 397 398 if decoder_path is not None and not os.path.exists(decoder_path): 399 raise ValueError(f"Decoder checkpoint at '{decoder_path}' could not be found.") 400 401 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 402 # before calling sam_model_registry. 403 abbreviated_model_type = model_type[:5] 404 if abbreviated_model_type not in _MODEL_TYPES: 405 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 406 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 407 raise RuntimeError( 408 "'mobile_sam' is required for the vit-tiny. 
" 409 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 410 ) 411 412 state, model_state = _load_checkpoint(checkpoint_path) 413 414 if _provided_checkpoint_path: 415 # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type' 416 # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination. 417 from micro_sam.models.build_sam import _validate_model_type 418 _provided_model_type = _validate_model_type(model_state) 419 420 # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type' 421 # Otherwise replace 'abbreviated_model_type' with the later. 422 if abbreviated_model_type != _provided_model_type: 423 # Printing the message below to avoid any filtering of warnings on user's end. 424 print( 425 f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', " 426 f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. " 427 f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. " 428 "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future." 429 ) 430 431 # Replace the extracted 'abbreviated_model_type' subjected to the model weights. 432 abbreviated_model_type = _provided_model_type 433 434 # Whether to update parameters necessary to initialize the model 435 if model_kwargs: # Checks whether model_kwargs have been provided or not 436 if abbreviated_model_type == "vit_t": 437 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 438 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 439 440 else: 441 sam = sam_model_registry[abbreviated_model_type]() 442 443 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 444 # Overwrites the SAM model by freezing the backbone and allow PEFT. 
445 if peft_kwargs and isinstance(peft_kwargs, dict): 446 # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference. 447 peft_kwargs.pop("quantize", None) 448 449 if abbreviated_model_type == "vit_t": 450 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 451 452 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 453 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 454 if flexible_load_checkpoint: 455 sam = _handle_checkpoint_loading(sam, model_state) 456 else: 457 sam.load_state_dict(model_state) 458 sam.to(device=device) 459 460 predictor = SamPredictor(sam) 461 predictor.model_type = abbreviated_model_type 462 predictor._hash = model_hash 463 predictor.model_name = model_type 464 predictor.checkpoint_path = checkpoint_path 465 466 # Add the decoder to the state if we have one and if the state is returned. 467 if decoder_path is not None and return_state: 468 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 469 470 if return_sam and return_state: 471 return predictor, sam, state 472 if return_sam: 473 return predictor, sam 474 if return_state: 475 return predictor, state 476 return predictor 477 478 479def _handle_checkpoint_loading(sam, model_state): 480 # Whether to handle the mismatch issues in a bit more elegant way. 481 # eg. while training for multi-class semantic segmentation in the mask encoder, 482 # parameters are updated - leading to "size mismatch" errors 483 484 new_state_dict = {} # for loading matching parameters 485 mismatched_layers = [] # for tracking mismatching parameters 486 487 reference_state = sam.state_dict() 488 489 for k, v in model_state.items(): 490 if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM. 
491 if reference_state[k].size() == v.size(): 492 new_state_dict[k] = v 493 else: 494 mismatched_layers.append(k) 495 496 reference_state.update(new_state_dict) 497 498 if len(mismatched_layers) > 0: 499 warnings.warn(f"The layers with size mismatch: {mismatched_layers}") 500 501 for mlayer in mismatched_layers: 502 if 'weight' in mlayer: 503 torch.nn.init.kaiming_uniform_(reference_state[mlayer]) 504 elif 'bias' in mlayer: 505 reference_state[mlayer].zero_() 506 507 sam.load_state_dict(reference_state) 508 509 return sam 510 511 512def export_custom_sam_model( 513 checkpoint_path: Union[str, os.PathLike], 514 model_type: str, 515 save_path: Union[str, os.PathLike], 516 with_segmentation_decoder: bool = False, 517 prefix: str = "sam.", 518) -> None: 519 """Export a finetuned Segment Anything Model to the standard model format. 520 521 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 522 523 Args: 524 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 525 model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 526 save_path: Where to save the exported model. 527 with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. 528 If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'. 529 prefix: The prefix to remove from the model parameter keys. 530 """ 531 state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path) 532 model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]) 533 534 # Store the 'decoder_state' as well, if desired. 
535 if with_segmentation_decoder: 536 if "decoder_state" not in state: 537 raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.") 538 decoder_state = state["decoder_state"] 539 save_state = {"model_state": model_state, "decoder_state": decoder_state} 540 else: 541 save_state = model_state 542 543 torch.save(save_state, save_path) 544 545 546def export_custom_qlora_model( 547 checkpoint_path: Optional[Union[str, os.PathLike]], 548 finetuned_path: Union[str, os.PathLike], 549 model_type: str, 550 save_path: Union[str, os.PathLike], 551) -> None: 552 """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format. 553 554 The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`. 555 556 Args: 557 checkpoint_path: The path to the base foundation model from which the new model has been finetuned. 558 finetuned_path: The path to the new finetuned model, using QLoRA. 559 model_type: The Segment Anything Model type corresponding to the checkpoint. 560 save_path: Where to save the exported model. 561 """ 562 # Step 1: Get the base SAM model: used to start finetuning from. 563 _, sam = get_sam_model( 564 model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, 565 ) 566 567 # Step 2: Load the QLoRA-style finetuned model. 568 ft_state, ft_model_state = _load_checkpoint(finetuned_path) 569 570 # Step 3: Identify LoRA layers from QLoRA model. 571 # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers. 572 # - then copy the LoRA layers from the QLoRA model to the new state dict 573 updated_model_state = {} 574 575 modified_attn_layers = set() 576 modified_mlp_layers = set() 577 578 for k, v in ft_model_state.items(): 579 if "blocks." 
in k: 580 layer_id = int(k.split("blocks.")[1].split(".")[0]) 581 if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1: 582 modified_attn_layers.add(layer_id) 583 updated_model_state[k] = v 584 if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1: 585 modified_mlp_layers.add(layer_id) 586 updated_model_state[k] = v 587 588 # Step 4: Next, we get all the remaining parameters from the base SAM model. 589 for k, v in sam.state_dict().items(): 590 if "blocks." in k: 591 layer_id = int(k.split("blocks.")[1].split(".")[0]) 592 if k.find("attn.qkv.") != -1: 593 if layer_id in modified_attn_layers: # We have LoRA in QKV layers, so we need to modify the key 594 k = k.replace("qkv", "qkv.qkv_proj") 595 elif k.find("mlp") != -1 and k.find("image_encoder") != -1: 596 if layer_id in modified_mlp_layers: # We have LoRA in MLP layers, so we need to modify the key 597 k = k.replace("mlp.", "mlp.mlp_layer.") 598 updated_model_state[k] = v 599 600 # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff) 601 ft_state['model_state'] = updated_model_state 602 603 # Step 6: Store the new "state" to "save_path" 604 torch.save(ft_state, save_path) 605 606 607def get_model_names() -> Iterable: 608 model_registry = models() 609 model_names = model_registry.registry.keys() 610 return model_names 611 612 613# 614# Functionality for precomputing image embeddings. 615# 616 617 618def _to_image(image): 619 input_ = image 620 ndim = input_.ndim 621 n_channels = 1 if ndim == 2 else input_.shape[-1] 622 623 # Map the input to three channels. 624 if ndim == 2: # Grayscale image -> replicate channels. 625 input_ = np.concatenate([input_[..., None]] * 3, axis=-1) 626 elif ndim == 3 and n_channels == 1: # Grayscale image -> replicate channels. 627 input_ = np.concatenate([input_] * 3, axis=-1) 628 elif ndim == 3 and n_channels == 2: # Two channels -> add a zero channel. 
629 zero_channel = np.zeros(input_.shape[:2] + (1,), dtype=input_.dtype) 630 input_ = np.concatenate([input_, zero_channel], axis=-1) 631 elif input_.ndim == 3 and n_channels == 3: # RGB input -> do nothing. 632 pass 633 elif input_.ndim == 3 and n_channels > 3: # More than three channels -> select first three. 634 warnings.warn(f"You provided an input with {n_channels} channels. Only the first three will be used.") 635 input_ = input_[..., :3] 636 else: 637 raise ValueError( 638 f"Invalid input dimensionality {ndim}. Expect either a 2D input (=grayscale image) " 639 "or a 3D input (= image with channels)." 640 ) 641 assert input_.ndim == 3 and input_.shape[-1] == 3 642 643 # Normalize the input per channel and bring it to uint8. 644 input_ = input_.astype("float32") 645 input_ -= input_.min(axis=(0, 1))[None, None] 646 input_ /= (input_.max(axis=(0, 1))[None, None] + 1e-7) 647 input_ = (input_ * 255).astype("uint8") 648 649 # Explicitly return a numpy array for compatibility with torchvision 650 # because the input_ array could be something like dask array. 651 return np.array(input_) 652 653 654@torch.no_grad 655def _compute_embeddings_batched(predictor, batched_images): 656 predictor.reset_image() 657 batched_tensors, original_sizes, input_sizes = [], [], [] 658 659 # Apply proeprocessing to all images in the batch, and then stack them. 660 # Note: after the transformation the images are all of the same size, 661 # so they can be stacked and processed as a batch, even if the input images were of different size. 
662 for image in batched_images: 663 tensor = predictor.transform.apply_image(image) 664 tensor = torch.as_tensor(tensor, device=predictor.device) 665 tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :] 666 667 original_sizes.append(image.shape[:2]) 668 input_sizes.append(tensor.shape[-2:]) 669 670 tensor = predictor.model.preprocess(tensor) 671 batched_tensors.append(tensor) 672 673 batched_tensors = torch.cat(batched_tensors) 674 features = predictor.model.image_encoder(batched_tensors) 675 676 predictor.original_size = original_sizes[-1] 677 predictor.input_size = input_sizes[-1] 678 predictor.features = features[-1] 679 predictor.is_image_set = True 680 681 return features, original_sizes, input_sizes 682 683 684# Wrapper of zarr.create dataset to support zarr v2 and zarr v3. 685def _create_dataset_with_data(group, name, data, chunks=None): 686 zarr_major_version = int(zarr.__version__.split(".")[0]) 687 if chunks is None: 688 chunks = data.shape 689 if zarr_major_version == 2: 690 ds = group.create_dataset(name, data=data, shape=data.shape, chunks=chunks) 691 elif zarr_major_version == 3: 692 ds = group.create_array(name, shape=data.shape, chunks=chunks, dtype=data.dtype) 693 ds[:] = data 694 else: 695 raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}") 696 return ds 697 698 699def _create_dataset_without_data(group, name, shape, dtype, chunks): 700 zarr_major_version = int(zarr.__version__.split(".")[0]) 701 if zarr_major_version == 2: 702 ds = group.create_dataset(name, shape=shape, dtype=dtype, chunks=chunks) 703 elif zarr_major_version == 3: 704 ds = group.create_array(name, shape=shape, chunks=chunks, dtype=dtype) 705 else: 706 raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}") 707 return ds 708 709 710def _write_batch(features, tile_ids, batched_embeddings, original_sizes, input_sizes, slices=None, n_slices=None): 711 712 # Pre-create / pre-fetch the datasets if we have slices. 
713 # (Dataset creation is not thread-safe) 714 if slices is not None: 715 datasets = {} 716 for tile_id, tile_embeddings, original_size, input_size in zip( 717 tile_ids, batched_embeddings, original_sizes, input_sizes 718 ): 719 ds_name = str(tile_id) 720 if ds_name in datasets: 721 continue 722 if ds_name in features: 723 datasets[ds_name] = features[ds_name] 724 continue 725 shape = (n_slices, 1) + tile_embeddings.shape 726 chunks = (1, 1) + tile_embeddings.shape 727 ds = _create_dataset_without_data(features, ds_name, shape=shape, dtype="float32", chunks=chunks) 728 ds.attrs["original_size"] = original_size 729 ds.attrs["input_size"] = input_size 730 datasets[ds_name] = ds 731 732 def _write_embed(i): 733 ds_name = str(tile_ids[i]) 734 tile_embeddings = batched_embeddings[i].unsqueeze(0) 735 if slices is None: 736 ds = _create_dataset_with_data(features, ds_name, data=tile_embeddings.cpu().numpy()) 737 ds.attrs["original_size"] = original_sizes[i] 738 ds.attrs["input_size"] = input_sizes[i] 739 elif ds_name in features: 740 ds = datasets[ds_name] 741 z = slices[i] 742 ds[z] = tile_embeddings.cpu().numpy() 743 744 n_tiles = len(tile_ids) 745 n_workers = min(mp.cpu_count(), n_tiles) 746 with futures.ThreadPoolExecutor(n_workers) as tp: 747 list(tp.map(_write_embed, range(n_tiles))) 748 749 750def _get_tiles_in_mask(mask, tiling, halo, z=None): 751 def _check_mask(tile_id): 752 tile = tiling.get_block_with_halo(tile_id, list(halo)) 753 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outer_block.begin, tile.outer_block.end)) 754 if z is not None: 755 outer_tile = (z,) + outer_tile 756 tile_mask = mask[outer_tile].astype("bool") 757 return None if tile_mask.sum() == 0 else tile_id 758 759 n_threads = mp.cpu_count() 760 with futures.ThreadPoolExecutor(n_threads) as tp: 761 tiles_in_mask = tp.map(_check_mask, range(tiling.number_of_blocks)) 762 return sorted([tile_id for tile_id in tiles_in_mask if tile_id is not None]) 763 764 765def 
def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
    """Compute the embeddings for the tiles of a 2D image and write them to the group "features" of `f`.

    Args:
        predictor: The predictor used to compute the embeddings.
        input_: The image; the tiling is defined over its first two axes.
        tile_shape: The shape of the tiles.
        halo: The halo that is added around each tile.
        f: The container in which the "features" group is created (or re-used).
        pbar_init: Callback to initialize the progress bar, called with the number of tiles and a description.
        pbar_update: Callback to advance the progress bar, called with the number of processed tiles.
        batch_size: The maximal number of tiles that are processed together.
        mask: Optional mask. If given, only tiles that contain mask foreground are processed
            and their ids are stored in the attribute "tiles_in_mask" of the group.

    Returns:
        The group holding the tile embeddings.
    """
    tiling = Blocking([0, 0], input_.shape[:2], tile_shape)

    features = f.require_group("features")
    features.attrs["shape"] = input_.shape[:2]
    features.attrs["tile_shape"] = tile_shape
    features.attrs["halo"] = halo

    if mask is None:
        tiles_in_mask = None
        tile_ids_to_compute = range(tiling.number_of_blocks)
        pbar_init(len(tile_ids_to_compute), "Compute Image Embeddings 2D tiled")
    else:
        tiles_in_mask = _get_tiles_in_mask(mask, tiling, halo)
        tile_ids_to_compute = tiles_in_mask
        pbar_init(len(tiles_in_mask), "Compute Image Embeddings 2D tiled with mask")

    # Derive the batches from the tiles that are actually computed. Deriving the number of batches
    # from the total number of tiles would yield empty (or needlessly small) batches for masked inputs,
    # and an empty batch cannot be passed through the image encoder.
    tile_ids_for_batches = [
        tile_ids_to_compute[start:start + batch_size]
        for start in range(0, len(tile_ids_to_compute), batch_size)
    ]

    for tile_ids in tile_ids_for_batches:
        batched_images = []
        for tile_id in tile_ids:
            tile = tiling.get_block_with_halo(tile_id, list(halo))
            outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outer_block.begin, tile.outer_block.end))
            batched_images.append(_to_image(input_[outer_tile]))

        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
        _write_batch(features, tile_ids, batched_embeddings, original_sizes, input_sizes)
        pbar_update(len(tile_ids))

    # The sizes differ per tile and are stored with the tile datasets, so they are None in the signature.
    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
    if mask is not None:
        features.attrs["tiles_in_mask"] = tiles_in_mask

    return features
class _BatchProvider:
    """Iterator over batches of (slice index, tile id) pairs for tiled 3D embedding computation.

    The tiles are traversed slice by slice. Each iteration step yields two lists of equal length,
    the slice indices and the corresponding tile ids, with at most `batch_size` entries.
    If `tiles_in_mask_per_slice` is given, only the tile ids listed for the respective slice are used,
    otherwise all tile ids in `range(n_tiles_per_plane)` are used for every slice.
    """

    def __init__(self, n_slices, n_tiles_per_plane, tiles_in_mask_per_slice, batch_size):
        self.n_tiles_total = (
            n_slices * n_tiles_per_plane if tiles_in_mask_per_slice is None
            else sum(len(val) for val in tiles_in_mask_per_slice.values())
        )

        self.n_batches = int(np.ceil(self.n_tiles_total / batch_size))
        self.n_slices = n_slices
        self.n_tiles_per_plane = n_tiles_per_plane
        self.tiles_in_mask_per_slice = tiles_in_mask_per_slice
        self.batch_size = batch_size

        # State of the iteration: the number of batches returned so far, the current slice
        # and the position within the tiles of the current slice.
        self.batch_id = 0
        self.z = 0
        self.tile_id = 0

    def _tiles_of_current_slice(self):
        # Without a mask every slice uses all tiles of the plane.
        if self.tiles_in_mask_per_slice is None:
            return list(range(self.n_tiles_per_plane))
        return self.tiles_in_mask_per_slice[self.z]

    def __iter__(self):
        return self

    def __next__(self):
        if self.batch_id >= self.n_batches:
            raise StopIteration

        current_tiles = self._tiles_of_current_slice()
        slices, tile_ids = [], []

        while len(tile_ids) < self.batch_size:
            # The current slice is exhausted: go to the next one, or stop after the last slice.
            if self.tile_id == len(current_tiles):
                self.z += 1
                self.tile_id = 0
                if self.z >= self.n_slices:
                    break
                current_tiles = self._tiles_of_current_slice()
                continue

            slices.append(self.z)
            tile_ids.append(current_tiles[self.tile_id])
            self.tile_id += 1

        self.batch_id += 1
        return slices, tile_ids
def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
    """Compute the embeddings for the tiles of all slices of a 3D input and write them to "features" in `f`.

    The tiling is defined over the last two axes of `input_`. Batches of (slice, tile) pairs are
    processed together. If a mask is given, only tiles with mask foreground in the respective slice are
    processed and the tile ids per slice are stored in the attribute "tiles_in_mask" of the group.

    Returns:
        The group holding the tile embeddings.
    """
    assert input_.ndim == 3

    n_slices, plane_shape = input_.shape[0], input_.shape[1:]
    tiling = Blocking([0, 0], plane_shape, tile_shape)

    features = f.require_group("features")
    features.attrs["shape"] = plane_shape
    features.attrs["tile_shape"] = tile_shape
    features.attrs["halo"] = halo

    n_tiles_per_plane = tiling.number_of_blocks

    # Determine which tiles have to be processed and the corresponding progress bar settings.
    if mask is None:
        tiles_in_mask_per_slice = None
        n_tiles_total = n_slices * n_tiles_per_plane
        description = "Compute Image Embeddings 3D tiled"
    else:
        tiles_in_mask_per_slice = {}
        for z in range(n_slices):
            tiles_in_mask_per_slice[z] = _get_tiles_in_mask(mask, tiling, halo, z=z)
        n_tiles_total = sum(len(val) for val in tiles_in_mask_per_slice.values())
        description = "Compute Image Embeddings 3D tiled masked"
    pbar_init(n_tiles_total, description)

    for slices, tile_ids in _BatchProvider(n_slices, n_tiles_per_plane, tiles_in_mask_per_slice, batch_size):
        batched_images = []
        for z, tile_id in zip(slices, tile_ids):
            tile = tiling.get_block_with_halo(tile_id, list(halo))
            in_plane = tuple(slice(beg, end) for beg, end in zip(tile.outer_block.begin, tile.outer_block.end))
            batched_images.append(_to_image(input_[(z,) + in_plane]))

        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
        _write_batch(
            features, tile_ids, batched_embeddings, original_sizes, input_sizes, slices=slices, n_slices=n_slices
        )
        pbar_update(len(tile_ids))

    if mask is not None:
        # The slice indices are converted to strings for the keys of the attribute.
        features.attrs["tiles_in_mask"] = {str(z): per_slice for z, per_slice in tiles_in_mask_per_slice.items()}

    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
    return features
def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update):
    """Compute the embeddings for a 2D image, or load them from `f` if they were saved before.

    The embeddings count as cached if a save path is given and `f` has the attribute "input_size".
    In both cases the predictor holds the embeddings of the image afterwards.

    Returns:
        The dictionary with the keys "features", "input_size" and "original_size".
    """
    already_cached = save_path is not None and "input_size" in f.attrs
    if already_cached:
        # Load the embeddings into memory and set them in the predictor.
        features = f["features"][:]
        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
        cached_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
        set_precomputed(predictor, cached_embeddings)
        return cached_embeddings

    # The embeddings are not cached, so we have to compute them.
    pbar_init(1, "Compute Image Embeddings 2D")
    predictor.reset_image()
    predictor.set_image(_to_image(input_))
    features = predictor.get_image_embedding().cpu().numpy()
    original_size, input_size = predictor.original_size, predictor.input_size
    pbar_update(1)

    # The embeddings and their signature are only saved if we have a save path.
    if save_path is not None:
        _create_dataset_with_data(f, "features", data=features)
        _write_embedding_signature(
            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
        )

    return {"features": features, "input_size": input_size, "original_size": original_size}


def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
    """Compute the tiled embeddings for a 2D image, or return them from `f` if they are already there.

    Returns:
        The dictionary with the keys "features", "input_size" and "original_size".
        The two sizes are None for newly computed tiled embeddings.
    """
    # The features were computed before if the signature has been written.
    if "input_size" in f.attrs:
        stored_features = f["features"]
        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
        return {"features": stored_features, "input_size": input_size, "original_size": original_size}

    # Otherwise compute them. Note: saving happens automatically because we
    # always write the features to zarr. If no save path is given we use an in-memory zarr.
    tile_features = _compute_tiled_features_2d(
        predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
    )
    return {"features": tile_features, "input_size": None, "original_size": None}
def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size):
    """Compute the embeddings for all slices of a 3D input, or load them from `f` if they are fully cached.

    If a save path is given, the embeddings are written slice by slice to the dataset "features" in `f`.
    A dataset that already exists is treated as partially computed: slices that contain any non-zero value
    are not recomputed. Without a save path the embeddings are collected in memory.

    Args:
        input_: The input data; the slices are taken along the first axis.
        predictor: The predictor used to compute the embeddings.
        f: The container for saving / loading the embeddings.
        save_path: The save path. It is only used to decide whether embeddings are saved and loaded.
        lazy_loading: Whether to return cached embeddings as dataset instead of loading them into memory.
        pbar_init: Callback to initialize the progress bar, called with the number of slices and a description.
        pbar_update: Callback to advance the progress bar.
        batch_size: The maximal number of slices that are processed together.

    Returns:
        The dictionary with the keys "features", "input_size" and "original_size".

    Raises:
        RuntimeError: If an existing "features" dataset does not have the expected shape or chunks.
    """
    # Check if the embeddings are already fully cached.
    if save_path is not None and "input_size" in f.attrs:
        # In this case we load the embeddings.
        features = f["features"] if lazy_loading else f["features"][:]
        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
        return {"features": features, "input_size": input_size, "original_size": original_size}

    # Otherwise we have to compute the embeddings.
    n_slices = input_.shape[0]

    # First check if we have a save path or not and set things up accordingly.
    if save_path is None:
        features = []
        save_features = False
        partial_features = False
    else:
        save_features = True
        embed_shape = (1, 256, 64, 64)
        shape = (n_slices,) + embed_shape
        chunks = (1,) + embed_shape
        if "features" in f:
            partial_features = True
            features = f["features"]
            if features.shape != shape or features.chunks != chunks:
                raise RuntimeError("Invalid partial features")
        else:
            partial_features = False
            features = _create_dataset_without_data(f, "features", shape=shape, chunks=chunks, dtype="float32")

    pbar_init(n_slices, "Compute Image Embeddings 3D")

    # The sizes of the last slice that went through the encoder; they are stored in the signature.
    input_size, original_size = None, None

    for z_start in range(0, n_slices, batch_size):
        z_stop = min(z_start + batch_size, n_slices)

        batched_images, batched_z = [], []
        for z in range(z_start, z_stop):
            # Skip feature computation in case of partial features in non-zero slice.
            if partial_features and np.count_nonzero(features[z]) != 0:
                # Still advance the progress bar, so that it reaches the total number of slices.
                pbar_update(1)
                continue
            batched_images.append(_to_image(input_[z]))
            batched_z.append(z)

        # All slices of this batch are cached already. An empty batch must not be passed on,
        # because it cannot be concatenated for the image encoder.
        if not batched_images:
            continue

        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
        input_size, original_size = input_sizes[-1], original_sizes[-1]

        for z, embedding in zip(batched_z, batched_embeddings):
            embedding = embedding.unsqueeze(0)
            if save_features:
                features[z] = embedding.cpu().numpy()
            else:
                # Add the leading axis along which the slices are concatenated below.
                features.append(embedding.unsqueeze(0))
            pbar_update(1)

    if input_size is None and n_slices > 0:
        # Every slice was cached, but the signature had not been written yet. The sizes are only
        # available from the encoder pass, so we run it for the last slice without overwriting its features.
        _, original_sizes, input_sizes = _compute_embeddings_batched(predictor, [_to_image(input_[n_slices - 1])])
        input_size, original_size = input_sizes[-1], original_sizes[-1]

    if save_features:
        _write_embedding_signature(
            f, input_, predictor, tile_shape=None, halo=None,
            input_size=input_size, original_size=original_size,
        )
    else:
        # Concatenate across the z axis.
        features = torch.cat(features).cpu().numpy()

    return {"features": features, "input_size": input_size, "original_size": original_size}
def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
    """Compute the tiled embeddings for a 3D input, or return them from `f` if they are already there.

    Returns:
        The dictionary with the keys "features", "input_size" and "original_size".
        The two sizes are None for newly computed tiled embeddings.
    """
    # The features were computed before if the signature has been written.
    if "input_size" in f.attrs:
        stored_features = f["features"]
        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
        return {"features": stored_features, "input_size": input_size, "original_size": original_size}

    # Otherwise compute them. Note: saving happens automatically because we
    # always write the features to zarr. If no save path is given we use an in-memory zarr.
    tile_features = _compute_tiled_features_3d(
        predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask
    )
    return {"features": tile_features, "input_size": None, "original_size": None}


def _compute_data_signature(input_):
    """Return the hexadecimal SHA1 digest of the raw bytes of the input data."""
    raw_bytes = np.asarray(input_).tobytes()
    return hashlib.sha1(raw_bytes).hexdigest()


# Create all metadata that is stored along with the embeddings.
def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None):
    """Return the signature dictionary for the given input, predictor and tiling settings.

    The data signature is only computed from `input_` if it is not passed.
    """
    if data_signature is None:
        data_signature = _compute_data_signature(input_)

    return {
        "data_signature": data_signature,
        # The tiling settings are stored as lists if they are given.
        "tile_shape": None if tile_shape is None else list(tile_shape),
        "halo": None if halo is None else list(halo),
        "model_type": predictor.model_type,
        "model_name": predictor.model_name,
        "micro_sam_version": __version__,
        "model_hash": getattr(predictor, "_hash", None),
    }


# Note: the input size and original size are different if embeddings are tiled or not.
# That's why we do not include them in the main signature that is being checked
# (_get_embedding_signature), but just add it for serialization here.
def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size):
    """Write the embedding signature and the two sizes to the attributes of `f`."""
    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
    signature["input_size"] = input_size
    signature["original_size"] = original_size
    for key, val in signature.items():
        f.attrs[key] = val
def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo):
    """Check that the embeddings saved in `f` match the signature of the current function call.

    A mismatch in "micro_sam_version", "model_hash" or "model_name" only raises a warning.

    Raises:
        RuntimeError: If any other signature entry is missing in `f` or does not match.
    """
    # We may have an empty zarr file that was already created to save the embeddings in.
    # In this case the embeddings will be computed and we don't need to perform any checks.
    if "input_size" not in f.attrs:
        return

    # These keys were recently added, so we don't want to fail yet if they don't
    # match in order to not invalidate previous embedding files.
    # Instead we just raise a warning. (For the version we probably also don't want to fail
    # in the future since it should not invalidate the embeddings).
    warn_only_keys = ("micro_sam_version", "model_hash", "model_name")

    expected_signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
    for key, val in expected_signature.items():
        # The entry is fine if the key is in the attrs and the value is matching.
        has_mismatch = key not in f.attrs or f.attrs[key] != val
        if not has_mismatch:
            continue

        if key not in warn_only_keys:
            raise RuntimeError(
                f"Embeddings file {save_path} is invalid due to mismatch in {key}: "
                f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file."
            )

        warnings.warn(
            f"The signature for {key} in embeddings file {save_path} has a mismatch: "
            f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. "
            "But please recompute them if model predictions don't look as expected."
        )
# Helper function for optional external progress bars.
def handle_pbar(verbose, pbar_init, pbar_update):
    """@private"""
    # Returns the tuple (pbar, pbar_init, pbar_update, pbar_close).
    # 'pbar' is only set if this function creates its own progress bar.

    # Noop to provide dummy functions.
    def noop(*args):
        pass

    # We are not verbose: all callbacks do nothing.
    if not verbose:
        return None, noop, noop, noop

    # We are verbose and have an external progress bar: its callbacks are used as they are
    # and there is nothing to close.
    if pbar_init is not None:
        assert pbar_update is not None
        return None, pbar_init, pbar_update, noop

    # We are verbose and don't have an external progress bar.
    assert pbar_update is None  # avoid inconsistent state of callbacks

    # Create our own progress bar and the callbacks for it.
    pbar = tqdm()

    def _init(total, description):
        pbar.total = total
        pbar.set_description(description)

    def _update(update):
        pbar.update(update)

    def _close():
        pbar.close()

    return pbar, _init, _update, _close
def precompute_image_embeddings(
    predictor: SamPredictor,
    input_: np.ndarray,
    save_path: Optional[Union[str, os.PathLike]] = None,
    lazy_loading: bool = False,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    batch_size: int = 1,
    mask: Optional[np.typing.ArrayLike] = None,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> ImageEmbeddings:
    """Compute the image embeddings (output of the encoder) for the input.

    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

    Args:
        predictor: The Segment Anything predictor.
        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
        save_path: Path to save the embeddings in a zarr container.
            By default, set to 'None', i.e. the computed embeddings will not be stored locally.
        lazy_loading: Whether to load all embeddings into memory or return an
            object to load them on demand when required. This only has an effect if 'save_path' is given
            and if the input is 3 dimensional. By default, set to 'False'.
        ndim: The dimensionality of the data. If not given will be deduced from the input data.
            By default, set to 'None', i.e. will be computed from the provided `input_`.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Whether to be verbose in the computation. By default, set to 'True'.
        batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
        mask: An optional mask to define areas that are ignored in the computation.
            The mask will be used within tiled embedding computation and tiles that don't contain any foreground
            in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            To enable using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The image embeddings.

    Raises:
        ValueError: If the dimensionality of the data is neither 2 nor 3.
    """
    ndim = input_.ndim if ndim is None else ndim
    # Validate the dimensionality first, so that no zarr container or progress bar is created
    # for an invalid call. The message reports the dimensionality that is actually used for the
    # dispatch, which differs from 'input_.ndim' if 'ndim' was passed explicitly.
    if ndim not in (2, 3):
        raise ValueError(f"Invalid dimensionality {ndim}, expect 2 or 3 dim data.")

    # Handle the embedding save_path.
    # We don't have a save path, open in memory zarr file to hold tiled embeddings.
    if save_path is None:
        f = zarr.group()
    else:
        # Check for existence before opening, since opening in mode 'a' creates the container.
        embeddings_exist = os.path.exists(save_path)
        f = zarr.open(save_path, mode="a")
        # The save path already exists. Embeddings will be loaded from it,
        # check that the saved embeddings in there match the parameters of the function call.
        if embeddings_exist:
            _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)

    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)

    # Close the progress bar also if the computation fails.
    try:
        if ndim == 2 and tile_shape is None:
            embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
        elif ndim == 2:
            embeddings = _compute_tiled_2d(
                input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
            )
        elif tile_shape is None:
            embeddings = _compute_3d(
                input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size
            )
        else:
            embeddings = _compute_tiled_3d(
                input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
            )
    finally:
        pbar_close()

    return embeddings
def set_precomputed(
    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
) -> SamPredictor:
    """Set the precomputed image embeddings for a predictor.

    Args:
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        tile_id: Index for the tile. This is required if the embeddings are tiled.

    Returns:
        The predictor with set features.

    Raises:
        ValueError: If the features are 5D and no index is given, or if they are 4D and an index is given.
    """
    # For tiled embeddings we select the data of the tile and set it via a second call without tile id.
    if tile_id is not None:
        tile_features = image_embeddings["features"][str(tile_id)]
        return set_precomputed(
            predictor,
            {
                "features": tile_features,
                "input_size": tile_features.attrs["input_size"],
                "original_size": tile_features.attrs["original_size"],
            },
            i=i,
        )

    device = predictor.device
    features = image_embeddings["features"]
    assert features.ndim in (4, 5), f"{features.ndim}"

    needs_index = features.ndim == 5
    if needs_index and i is None:
        raise ValueError("The data is 3D so an index i is needed.")
    if not needs_index and i is not None:
        raise ValueError("The data is 2D so an index is not needed.")

    # Select the features (for the index, if it is given) and convert them to a tensor if necessary.
    if torch.is_tensor(features):
        selected = features if i is None else features[i]
    else:
        selected = torch.from_numpy(features[:] if i is None else features[i])
    predictor.features = selected.to(device)

    predictor.original_size = image_embeddings["original_size"]
    predictor.input_size = image_embeddings["input_size"]
    predictor.is_image_set = True

    return predictor


#
# Misc functionality
#


def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """Compute the intersection over union of two masks.

    Only elements that are equal to 1 are treated as foreground.

    Args:
        mask1: The first mask.
        mask2: The second mask.

    Returns:
        The intersection over union of the two masks.
    """
    foreground1, foreground2 = mask1 == 1, mask2 == 1
    overlap = np.logical_and(foreground1, foreground2).sum()
    union = np.logical_or(foreground1, foreground2).sum()
    # The small constant avoids a division by zero if both masks are empty.
    eps = 1e-7
    return float(overlap) / (float(union) + eps)
def get_centers_and_bounding_boxes(
    segmentation: np.ndarray, mode: str = "v"
) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
    """Returns the center coordinates of the foreground instances in the ground-truth.

    Args:
        segmentation: The segmentation.
        mode: Determines the functionality used for computing the centers.
            If 'v', the point of maximal distance to the object boundary is used as center.
            This center is guaranteed to lie inside the object, also for concave shapes.
            If 'p' the object's centroids computed by skimage are used.

    Returns:
        A dictionary that maps object ids to the corresponding centroid.
        A dictionary that maps object_ids to the corresponding bounding box.
    """
    assert mode in ["p", "v"], "Choose either 'p' for regionprops centroids or 'v' for distance-based centers"

    properties = regionprops(segmentation)

    if mode == "p":
        center_coordinates = {prop.label: prop.centroid for prop in properties}
    elif mode == "v":
        # Use the point of maximal distance to the object boundary as the center.
        # In contrast to the centroid, this point is guaranteed to lie inside the object,
        # also for concave shapes. This replaces vigra.filters.eccentricityCenters.
        # The boundaries and the distance transform are computed once for the whole
        # segmentation, instead of once per object.
        ndim = segmentation.ndim
        # Pad so objects touching the image border also get a boundary there,
        # matching a per-object padded distance transform.
        inner_boundaries = find_boundaries(np.pad(segmentation, 1), mode="inner")
        distances = distance_transform(inner_boundaries == 0)

        def _distance_center(prop):
            start, stop = prop.bbox[:ndim], prop.bbox[ndim:]
            # Cut out the bounding box from the distances; the offset of 1 accounts for the padding.
            in_bbox = distances[tuple(slice(b + 1, e + 1) for b, e in zip(start, stop))]
            # Restrict the maximum to the pixels of the object itself.
            object_distances = np.where(prop.image, in_bbox, -1.0)
            local_center = np.unravel_index(int(np.argmax(object_distances)), object_distances.shape)
            # Go from coordinates in the bounding box to coordinates in the segmentation.
            return tuple(int(c + o) for c, o in zip(local_center, start))

        center_coordinates = {prop.label: _distance_center(prop) for prop in properties}

    bbox_coordinates = {prop.label: prop.bbox for prop in properties}

    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
    return center_coordinates, bbox_coordinates


def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
    """Helper function to load image data from file.

    Args:
        path: The filepath to the image data.
        key: The internal filepath for complex data formats like hdf5.
        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.

    Returns:
        The image data.
    """
    # Without a key the file is read as a regular image.
    if key is None:
        return imageio.imread(path)

    with open_file(path, mode="r") as f:
        dataset = f[key]
        # The data is loaded into memory, unless lazy loading is requested.
        return dataset if lazy_loading else dataset[:]
def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
    """Convert the segmentation to one-hot encoded masks.

    Args:
        segmentation: The segmentation. Any integer dtype is supported.
        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
            By default, computes the number of ids from the provided `segmentation` masks.

    Returns:
        The one-hot encoded masks.

    Raises:
        RuntimeError: If `segmentation_ids` is given, but is empty or contains zero.
    """
    if segmentation_ids is None:
        masks = segmentation
        n_ids = int(segmentation.max())

    else:
        msg = "No foreground objects were found."
        if len(segmentation_ids) == 0:  # The list should not be completely empty.
            raise RuntimeError(msg)

        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
            raise RuntimeError(msg)

        # the segmentation ids have to be sorted
        segmentation_ids = np.sort(segmentation_ids)

        # set the non selected objects to zero and relabel sequentially
        # (on a copy, so that the input segmentation is not modified)
        masks = segmentation.copy()
        masks[~np.isin(masks, segmentation_ids)] = 0
        masks = relabel_sequential(masks)[0]
        n_ids = len(segmentation_ids)

    # The masks are used as index for 'scatter_', which only accepts int64 indices.
    # Casting here makes the function work for label images of any integer dtype
    # (e.g. int32 or uint16). The cast also copies, so the input is never aliased.
    masks = torch.from_numpy(np.asarray(masks).astype("int64"))

    one_hot_shape = (n_ids + 1,) + masks.shape
    masks = masks.unsqueeze(0)  # add dimension to scatter
    # Drop the first channel, which corresponds to the background (id 0).
    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]

    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
    masks = masks.unsqueeze(1)
    return masks
def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
    """Get a suitable block shape for chunking a given shape.

    The primary use for this is determining chunk sizes for
    zarr arrays or block shapes for parallelization.

    Args:
        shape: The image or volume shape.

    Returns:
        The block shape.

    Raises:
        ValueError: If the shape is neither 2 nor 3 dimensional.
    """
    # The maximal block shape for each supported dimensionality.
    max_block_shapes = {2: (1024, 1024), 3: (32, 256, 256)}

    ndim = len(shape)
    if ndim not in max_block_shapes:
        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")

    # The block may not be larger than the data along any axis.
    return tuple(min(max_len, axis_len) for max_len, axis_len in zip(max_block_shapes[ndim], shape))
def micro_sam_info() -> None:
    """Display μSAM information using a rich console.

    Prints the version, the documentation and publication links, the cache directory,
    the available models, system information and the available compute device.
    The command line option '--download' can be used to download all or selected models.
    """
    import psutil
    import platform
    import argparse
    from rich import progress
    from rich.panel import Panel
    from rich.table import Table
    from rich.console import Console

    import torch
    import micro_sam

    parser = argparse.ArgumentParser(description="μSAM Information Booth")
    parser.add_argument(
        "--download", nargs="+", metavar=("WHAT", "KIND"),
        help="Downloads the pretrained SAM models. "
        "'--download models' -> downloads all pretrained models; "
        "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; "
        "'--download model/models vit_b_lm' -> downloads a single specified model."
    )
    args = parser.parse_args()

    # Open up a new console.
    console = Console()

    # The header for information CLI.
    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
    console.print("-" * console.width)

    # μSAM version panel.
    console.print(
        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
    )

    # The documentation link panel.
    console.print(
        Panel(
            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
        )
    )

    # The publication panel.
    console.print(
        Panel(
            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
        )
    )

    # Creating a cache directory when users' run `micro_sam.info`.
    cache_dir = get_cache_directory()
    os.makedirs(cache_dir, exist_ok=True)

    # The cache directory panel.
    console.print(
        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{cache_dir}", title="Cache Directory")
    )

    # We have a simple versioning logic here for mapping model versions.
    # 'available_models' holds the labels that are displayed (the model name, for most models with a
    # version suffix), 'model_name_lookup' maps both the plain model name and the displayed label to
    # the model name that has to be passed on for downloading the model.
    available_models = []
    model_name_lookup = {}
    # For our models, the BioImageIO ModelZoo upload structure is such that:
    # '/1/files' corresponds to v2 models.
    # '/1.1/files' corresponds to v3 models.
    # '/1.2/files' corresponds to v4 models.
    modelzoo_versions = (("/1/files", "v2"), ("/1.1/files", "v3"), ("/1.2/files", "v4"))

    for model_name, model_path in models().urls.items():
        if model_name.endswith("decoder"):  # We filter out the decoder models.
            continue

        labels = []
        if "https://dl.fbaipublicfiles.com/segment_anything/" in model_path:  # Valid v1 SAM models.
            labels.append(model_name)

        if "https://owncloud.gwdg.de/" in model_path:  # Our own hosted models (in their v1 mode quite often)
            # The MobileSAM model (vit_t) is listed without a version.
            labels.append(model_name if model_name == "vit_t" else f"{model_name} (v1)")

        for url_marker, version in modelzoo_versions:
            if url_marker in model_path:
                labels.append(f"{model_name} ({version})")

        if labels:
            available_models.extend(labels)
            model_name_lookup[model_name] = model_name
            for label in labels:
                model_name_lookup[label] = model_name

    model_list = "\n".join(available_models)

    # The available models panel.
    console.print(
        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
    )

    # The system information table.
    total_memory = psutil.virtual_memory().total / (1024 ** 3)
    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
    table.add_column("Property")
    table.add_column("Value", style="bold #56B4E9")
    table.add_row("System", platform.system())
    table.add_row("Node Name", platform.node())
    table.add_row("Release", platform.release())
    table.add_row("Version", platform.version())
    table.add_row("Machine", platform.machine())
    table.add_row("Processor", platform.processor())
    table.add_row("Platform", platform.platform())
    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
    console.print(table)

    # The device information and check for available GPU acceleration.
    default_device = _get_default_device()

    if default_device == "cuda":
        device_index = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_index)
        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
    elif default_device == "mps":
        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
    else:
        console.print(
            Panel(
                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
                title="Device Information"
            )
        )

    # The section allowing to download models.
    # NOTE: In future, can be extended to download sample data.
    if args.download:
        mode, *model_types = [t.lower() for t in args.download]

        if mode not in {"models", "model"}:
            console.print(f"[red]Unknown option for --download: {mode}[/]")
            return

        if not model_types:  # If user did not specify, we will download all models.
            # Use the plain model names (each model once), not the displayed labels with version suffix.
            download_list = list(dict.fromkeys(model_name_lookup.values()))
        else:
            # The requested models are validated against the plain model names (and the displayed labels).
            # Validating against the displayed labels only would reject e.g. 'vit_b_lm',
            # because it is listed with a version suffix.
            incorrect_models = [m for m in model_types if m not in model_name_lookup]
            if incorrect_models:
                console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error"))
                return
            download_list = [model_name_lookup[m] for m in model_types]

        with progress.Progress(
            progress.SpinnerColumn(),
            progress.TextColumn("[progress.description]{task.description}"),
            progress.BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.0f}%",
            progress.TimeRemainingColumn(),
            console=console,
        ) as prog:
            task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list))
            for model_type in download_list:
                prog.update(task, description=f"Downloading [cyan]{model_type}[/]…")
                _download_sam_model(model_type=model_type)
                prog.advance(task)

        console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))


#
# Functionality to convert mask predictions to an instance segmentation via non-maximum suppression.
# The functionality for computing NMS for masks is taken from CellSeg1:
# https://github.com/Nuisal/cellseg1/blob/1c027c2568b83494d2662d1fbecec9aafb478ee0/mask_nms.py
#


def _overlap_matrix(boxes):
    """Determine which pairs of boxes intersect with a positive area.

    Args:
        boxes: Tensor with one box per row; columns 0-3 are read as (x1, y1, x2, y2).

    Returns:
        A boolean (N, N) tensor that is True where box i and box j intersect.
    """
    x1 = torch.max(boxes[:, None, 0], boxes[:, 0])
    y1 = torch.max(boxes[:, None, 1], boxes[:, 1])
    x2 = torch.min(boxes[:, None, 2], boxes[:, 2])
    y2 = torch.min(boxes[:, None, 3], boxes[:, 3])

    # A negative extent means the boxes are disjoint along that axis.
    w = torch.clamp(x2 - x1, min=0)
    h = torch.clamp(y2 - y1, min=0)

    return (w * h) > 0


def _calculate_ious_between_pred_masks(masks, boxes, diagonal_value=1):
    """Compute the pairwise intersection over union between masks.

    The mask IoU is only evaluated for pairs whose boxes overlap, all other entries stay zero.

    Args:
        masks: Tensor of stacked binary masks, indexed by mask along the first axis.
        boxes: The boxes of the masks in (x1, y1, x2, y2) format.
        diagonal_value: The value written to the diagonal of the result.

    Returns:
        The symmetric (N, N) matrix of IoU values.
    """
    n_points = masks.shape[0]
    m = torch.zeros((n_points, n_points))

    overlap_m = _overlap_matrix(boxes)

    for i in range(n_points):
        # Only the upper triangle is computed, the matrix is mirrored afterwards.
        js = torch.where(overlap_m[i])[0]
        js_half = js[js > i].to(masks.device)

        if len(js_half) > 0:
            intersection = torch.logical_and(masks[i], masks[js_half]).sum(dim=(1, 2))
            union = torch.logical_or(masks[i], masks[js_half]).sum(dim=(1, 2))
            iou = intersection / union
            m[i, js_half] = iou

    m = m + m.T
    m.fill_diagonal_(diagonal_value)
    return m


def _calculate_iomin_between_pred_masks(masks, boxes, eps=1e-6):
    """Compute the pairwise intersection over the smaller mask area.

    Args:
        masks: Tensor of stacked binary masks, indexed by mask along the first axis.
        boxes: The boxes of the masks in (x1, y1, x2, y2) format.
        eps: Small constant added to the denominator to avoid a division by zero.

    Returns:
        The (N, N) matrix of intersection-over-minimum values.
    """
    overlap_m = _overlap_matrix(boxes)

    # Flatten spatial dimensions: (N, H*W) or (N, D*H*W)
    N = masks.shape[0]
    masks_flat = masks.reshape(N, -1).float()

    # Per-mask area
    areas = masks_flat.sum(dim=1)  # (N,)

    # Pairwise intersections via matrix multiplication
    # inter[i, j] = sum_k masks_flat[i, k] * masks_flat[j, k]
    inter = masks_flat @ masks_flat.t()  # (N, N)

    # Denominator: min area of the two masks
    min_areas = torch.minimum(areas[:, None], areas[None, :])  # (N, N)

    # IoMin = intersection / min(area_i, area_j)
    iomin = inter / (min_areas + eps)

    # Set elements without any overlap explicitly to zero.
    iomin[~overlap_m] = 0
    return iomin


def _greedy_nms_from_overlaps(overlap_matrix, scores, nms_thresh):
    """Run greedy non-maximum suppression on a precomputed overlap matrix.

    Candidates are visited in order of descending score. Each visited candidate is kept
    and suppresses all remaining candidates whose overlap with it exceeds `nms_thresh`.

    Args:
        overlap_matrix: The (N, N) matrix of pairwise overlap values.
        scores: The score per candidate, as a 1d tensor.
        nms_thresh: The overlap threshold above which a candidate is suppressed.

    Returns:
        The indices of the kept candidates as a long tensor (also if nothing is kept).
    """
    order = torch.argsort(scores, descending=True)

    keep = []
    while len(order) > 0:
        best = int(order[0])
        keep.append(best)

        if len(order) == 1:
            break

        remaining = order[1:]
        order = remaining[overlap_matrix[best, remaining] <= nms_thresh]

    # The explicit dtype keeps the result usable as an index even if it is empty.
    return torch.tensor(keep, dtype=torch.long)


def _to_cpu_tensor(data):
    """Return `data` as a detached tensor on the cpu."""
    return (data.detach() if isinstance(data, torch.Tensor) else torch.tensor(data)).cpu()


def _batched_mask_nms(masks, boxes, scores, nms_thresh, intersection_over_min):
    """Apply mask-based NMS to masks that share the same coordinate system.

    Args:
        masks: The stacked binary masks.
        boxes: The boxes of the masks in (x1, y1, x2, y2) format.
        scores: The score per mask.
        nms_thresh: The overlap threshold above which a mask is suppressed.
        intersection_over_min: Whether to use intersection over minimum area instead of IoU.

    Returns:
        The indices of the masks that are kept.
    """
    boxes = _to_cpu_tensor(boxes)
    scores = _to_cpu_tensor(scores)
    masks = _to_cpu_tensor(masks)

    if intersection_over_min:
        iou_matrix = _calculate_iomin_between_pred_masks(masks, boxes)
    else:
        iou_matrix = _calculate_ious_between_pred_masks(masks, boxes)
    return _greedy_nms_from_overlaps(iou_matrix, scores, nms_thresh)


def _xywh_to_xyxy(boxes):
    """Convert boxes from (x, y, width, height) to (x1, y1, x2, y2) as a float32 tensor.

    The input is not modified.
    """
    boxes = boxes.clone() if isinstance(boxes, torch.Tensor) else torch.tensor(boxes)
    boxes = boxes.to(torch.float32)
    boxes[:, 2] += boxes[:, 0]
    boxes[:, 3] += boxes[:, 1]
    return boxes


def _infer_tiled_shape(predictions):
    """Derive the full output shape from tiled predictions.

    For each prediction the tile offset is the difference between its 'global_bbox' and
    'bbox' origin; the shape is the maximal extent of offset plus tile-local mask shape.
    """
    shape = [0, 0]
    for pred in predictions:
        bbox, global_bbox = pred["bbox"], pred["global_bbox"]
        # The offset is given in (x, y) order, the shape in (y, x) order.
        offset = (global_bbox[0] - bbox[0], global_bbox[1] - bbox[1])
        mask_shape = pred["segmentation"].shape
        shape[0] = max(shape[0], offset[1] + mask_shape[0])
        shape[1] = max(shape[1], offset[0] + mask_shape[1])
    return tuple(shape)


def _calculate_tiled_mask_overlap_matrix(masks, boxes, global_boxes, intersection_over_min):
    """Compute the pairwise overlap of masks that live in different tiles.

    Args:
        masks: The tile-local binary masks.
        boxes: The tile-local boxes in (x, y, width, height) format.
        global_boxes: The boxes in global coordinates in (x, y, width, height) format.
        intersection_over_min: Whether to use intersection over minimum area instead of IoU.

    Returns:
        The symmetric (N, N) matrix of overlap values with ones on the diagonal.
    """
    n_masks = len(masks)
    overlap_scores = torch.zeros((n_masks, n_masks))

    boxes = (
        boxes.detach().cpu().to(dtype=torch.long)
        if isinstance(boxes, torch.Tensor) else torch.tensor(boxes, dtype=torch.long)
    )
    global_boxes = (
        global_boxes.detach().cpu().to(dtype=torch.long)
        if isinstance(global_boxes, torch.Tensor) else torch.tensor(global_boxes, dtype=torch.long)
    )
    global_boxes_xyxy = _xywh_to_xyxy(global_boxes).to(torch.long)
    overlap_m = _overlap_matrix(global_boxes_xyxy)
    masks = [mask.detach().cpu() if isinstance(mask, torch.Tensor) else torch.tensor(mask) for mask in masks]
    areas = torch.tensor([float(mask.sum()) for mask in masks], dtype=torch.float32)

    # Work with plain integers for the slicing arithmetic below.
    corners = global_boxes_xyxy.tolist()
    # The (x, y) offset of each tile in the global coordinate system.
    offsets = (global_boxes[:, :2] - boxes[:, :2]).tolist()

    for i in range(n_masks):
        js = torch.where(overlap_m[i])[0]
        # Only the upper triangle is computed, the matrix is mirrored afterwards.
        js_half = js[js > i].tolist()
        if len(js_half) == 0:
            continue

        off_xi, off_yi = offsets[i]
        for j in js_half:
            off_xj, off_yj = offsets[j]
            # The intersection of both boxes in global coordinates.
            x1 = max(corners[i][0], corners[j][0])
            y1 = max(corners[i][1], corners[j][1])
            x2 = min(corners[i][2], corners[j][2])
            y2 = min(corners[i][3], corners[j][3])

            # Crop both masks to the intersection, expressed in their tile-local coordinates.
            mask_i = masks[i][y1 - off_yi:y2 - off_yi, x1 - off_xi:x2 - off_xi]
            mask_j = masks[j][y1 - off_yj:y2 - off_yj, x1 - off_xj:x2 - off_xj]

            intersection = torch.logical_and(mask_i, mask_j).sum()
            min_area = torch.minimum(areas[i], areas[j])
            denominator = min_area if intersection_over_min else areas[i] + areas[j] - intersection
            overlap_scores[i, j] = intersection / denominator

    overlap_scores = overlap_scores + overlap_scores.T
    overlap_scores.fill_diagonal_(1)
    return overlap_scores


def _batched_tiled_mask_nms(masks, boxes, global_boxes, scores, nms_thresh, intersection_over_min):
    """Apply mask-based NMS to masks from a tiled prediction.

    Args:
        masks: The tile-local binary masks.
        boxes: The tile-local boxes in (x, y, width, height) format.
        global_boxes: The boxes in global coordinates in (x, y, width, height) format.
        scores: The score per mask.
        nms_thresh: The overlap threshold above which a mask is suppressed.
        intersection_over_min: Whether to use intersection over minimum area instead of IoU.

    Returns:
        The indices of the masks that are kept.
    """
    scores = _to_cpu_tensor(scores)
    iou_matrix = _calculate_tiled_mask_overlap_matrix(masks, boxes, global_boxes, intersection_over_min)
    return _greedy_nms_from_overlaps(iou_matrix, scores, nms_thresh)


def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    shape: Optional[Tuple[int, int]] = None,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
    with_background: bool = False,
    merge_exclusively: bool = True,
) -> np.ndarray:
    """Convert the output of the automatic mask generation to an instance segmentation.

    Args:
        masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`,
            or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask.
        shape: The shape of the output segmentation. If None, it will be derived from the mask input.
            If the masks were predicted with tiling then the shape must be given.
        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.
            By default, set to 'True'.
        with_background: Whether to remove the largest object, which often covers the background for AMG.
        merge_exclusively: Whether to exclude previous merged masks from merging.

    Returns:
        The instance segmentation.

    Raises:
        ValueError: If `masks` is empty and no `shape` is given.
    """
    # Larger masks are written first, so that smaller masks can claim the remaining free pixels.
    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    if shape is None:
        if not masks:
            raise ValueError("The shape of the segmentation cannot be derived from an empty list of masks.")
        shape = masks[0]["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    def require_numpy(mask):
        return mask.cpu().numpy() if torch.is_tensor(mask) else mask

    seg_id = 1
    for mask_data in masks:
        area = mask_data["area"]
        if (area < min_object_size) or (max_object_size is not None and area > max_object_size):
            continue

        this_mask = require_numpy(mask_data["segmentation"])
        this_seg_id = mask_data.get("seg_id", seg_id)
        if "global_bbox" in mask_data:
            # Tiled prediction: crop the object from the tile-local mask and
            # write it to the matching bounding box in the global segmentation.
            bb = mask_data["bbox"]
            bb = np.s_[bb[1]:bb[1] + bb[3], bb[0]:bb[0] + bb[2]]
            global_bb = mask_data["global_bbox"]
            global_bb = np.s_[global_bb[1]:global_bb[1] + global_bb[3], global_bb[0]:global_bb[0] + global_bb[2]]
            if merge_exclusively:
                this_mask = np.logical_and(this_mask[bb], segmentation[global_bb] == 0)
            else:
                this_mask = this_mask[bb]
            # segmentation[global_bb] is a view, so this writes into the segmentation.
            segmentation[global_bb][this_mask] = this_seg_id
        else:
            if merge_exclusively:
                this_mask = np.logical_and(this_mask, segmentation == 0)
            segmentation[this_mask] = this_seg_id
        seg_id = this_seg_id + 1

    block_shape = (512, 512)
    if label_masks:
        segmentation_cc = np.zeros_like(segmentation, dtype=segmentation.dtype)
        segmentation_cc = parallel_impl.label(segmentation, out=segmentation_cc, block_shape=block_shape)
        segmentation = segmentation_cc

    # Objects can be smaller than their initial area after merging and labeling, so we filter again.
    seg_ids, sizes = parallel_impl.unique(segmentation, return_counts=True, block_shape=block_shape)
    filter_ids = seg_ids[sizes < min_object_size]
    if with_background:
        bg_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [bg_id]])

    filter_mask = np.zeros(segmentation.shape, dtype="bool")
    filter_mask = parallel_impl.isin(segmentation, filter_ids, out=filter_mask, block_shape=block_shape)
    segmentation[filter_mask] = 0
    # Use the returned array instead of relying on the relabeling being applied in-place.
    segmentation = parallel_impl.relabel_consecutive(segmentation, block_shape=block_shape)[0]

    return segmentation


def _size_filter_indices(areas, min_size, max_size):
    """Get the indices of all masks that pass the size filter.

    Args:
        areas: The area per mask.
        min_size: Masks must be larger than this value. Not applied if it is not positive.
        max_size: Masks must be smaller than this value. Not applied if it is None.

    Returns:
        The indices of the masks to keep as a long tensor (also if nothing is kept).
    """
    keep = [
        i for i, area in enumerate(areas)
        if (min_size <= 0 or area > min_size) and (max_size is None or area < max_size)
    ]
    return torch.tensor(keep, dtype=torch.long)


def apply_nms(
    predictions: List[Dict[str, Any]],
    min_size: int,
    shape: Optional[Tuple[int, int]] = None,
    perform_box_nms: bool = False,
    nms_thresh: float = 0.9,
    max_size: Optional[int] = None,
    intersection_over_min: bool = False,
) -> np.ndarray:
    """Apply non-maximum suppression to mask predictions from a segment anything model.

    Args:
        predictions: The mask predictions from SAM.
        min_size: The minimum mask size to keep in the output.
        shape: The shape of the output segmentation.
            For tiled predictions this is inferred from the tile-local mask shapes if it is not passed.
        perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
        nms_thresh: The threshold for filtering out objects in NMS.
        max_size: The maximum mask size to keep in the output.
        intersection_over_min: Whether to perform intersection over the minimum overlap shape
            or to perform intersection over union.

    Returns:
        The segmentation obtained from merging the masks left after NMS.
    """
    # Without predictions there is nothing to suppress. We can only return an
    # empty segmentation if the output shape is known.
    if len(predictions) == 0 and shape is not None:
        return np.zeros(shape, dtype="uint32")

    # Check if the input comes with a 'global_bbox' attribute. If it does, then the predictions are from
    # a tiled prediction. In this case, we have to take the coordinates w.r.t. the tiling into account.
    is_tiled = "global_bbox" in predictions[0]
    if is_tiled and shape is None:
        shape = _infer_tiled_shape(predictions)

    masks = [pred["segmentation"] for pred in predictions]
    # Tile-local masks do not share a coordinate system, so they are not stacked.
    nms_masks = None if is_tiled else torch.cat([mask[None] for mask in masks], dim=0)
    data = amg_utils.MaskData(masks=masks, iou_preds=torch.tensor([pred["predicted_iou"] for pred in predictions]))
    data["boxes"] = torch.tensor(np.array([pred["bbox"] for pred in predictions]))
    data["area"] = [int(mask.sum()) for mask in data["masks"]]
    data["stability_scores"] = torch.tensor([pred["stability_score"] for pred in predictions])
    if is_tiled:
        data["global_boxes"] = torch.tensor(np.array([pred["global_bbox"] for pred in predictions]))

    if min_size > 0 or max_size is not None:
        # The index tensor has an integer dtype even if it is empty,
        # so that it stays valid for indexing when all masks are filtered out.
        keep_by_size = _size_filter_indices(data["area"], min_size, max_size)
        data.filter(keep_by_size)
        if nms_masks is not None:
            nms_masks = nms_masks[keep_by_size]

    if len(data["masks"]) == 0:
        if shape is None:
            shape = predictions[0]["segmentation"].shape
        return np.zeros(shape, dtype="uint32")

    scores = data["iou_preds"] * data["stability_scores"]
    boxes = _xywh_to_xyxy(data["global_boxes"] if is_tiled else data["boxes"])
    if perform_box_nms:
        assert not intersection_over_min, "intersection_over_min is not implemented for box NMS"
        keep_by_nms = batched_nms(
            boxes,
            scores,
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=nms_thresh,
        )
    elif is_tiled:
        keep_by_nms = _batched_tiled_mask_nms(
            masks=data["masks"],
            boxes=data["boxes"],
            global_boxes=data["global_boxes"],
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    else:
        keep_by_nms = _batched_mask_nms(
            masks=nms_masks,
            boxes=boxes,
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    data.filter(keep_by_nms)

    if is_tiled:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box, "global_bbox": global_box}
            for mask, area, box, global_box in zip(data["masks"], data["area"], data["boxes"], data["global_boxes"])
        ]
    else:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box}
            for mask, area, box in zip(data["masks"], data["area"], data["boxes"])
        ]

    if shape is None:
        shape = predictions[0]["segmentation"].shape
    if mask_data:
        segmentation = mask_data_to_segmentation(mask_data, shape=shape, min_object_size=min_size)
    else:  # In case all objects have been filtered out due to size filtering.
        segmentation = np.zeros(shape, dtype="uint32")

    return segmentation
def get_cache_directory() -> Path:
    """Get micro-sam cache directory location.

    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.

    Returns:
        The path to the cache directory.
    """
    cache_directory = os.environ.get("MICROSAM_CACHEDIR")
    if cache_directory is None:
        # The default location is only determined if no custom directory is set.
        cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
    return Path(cache_directory)
Get micro-sam cache directory location.
Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
def microsam_cachedir() -> Union[str, Path]:
    """Return the micro-sam cache directory.

    Returns the top level cache directory for micro-sam models and sample data.

    Every time this function is called, we check for any user updates made to
    the MICROSAM_CACHEDIR os environment variable since the last time.

    Returns:
        The value of MICROSAM_CACHEDIR if it is set to a non-empty value,
        otherwise the result of `pooch.os_cache("micro_sam")`.
    """
    # An unset and an empty environment variable both fall back to the default location.
    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
    return cache_directory
Return the micro-sam cache directory.
Returns the top level cache directory for micro-sam models and sample data.
Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.
def models():
    """Return the segmentation models registry.

    The registry is built from scratch on every call. This ensures that changes
    to the micro-sam cache directory location made by the user are picked up.
    """
    # The registry is created below the current cache directory.
    model_folder = os.path.join(microsam_cachedir(), "models")

    # The model files are verified with xxhash, see
    # https://github.com/computational-cell-analytics/micro-sam/issues/283
    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
    # To generate the xxh128 hash:
    #     xxh128sum filename
    encoder_hashes = {
        # The default segment anything models:
        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
        # The current version of our models in the modelzoo.
        # LM generalist models:
        "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d",
        "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186",
        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
        # EM models:
        "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d",
        "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438",
        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
        # Histopathology models:
        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
        # Medical Imaging models:
        "vit_b_medical_imaging": "xxh128:40169f1e3c03a4b67bff58249c176d92",
    }
    # The hashes of the additional decoders for instance segmentation.
    decoder_hashes = {
        # LM generalist models:
        "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a",
        "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba",
        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
        # EM models:
        "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294",
        "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2",
        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
        # Histopathology models:
        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
        # Medical Imaging models:
        "vit_b_medical_imaging_decoder": "xxh128:9e498b12f526f119b96c88be76e3b2ed",
    }

    # The download locations of the encoder weights.
    encoder_locations = {
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt",
        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt",
        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt",  # noqa
        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt",  # noqa
        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/f5Ol4FrjPQWfjUF/download",
    }
    # The download locations of the decoder weights.
    decoder_locations = {
        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt",  # noqa
        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt",  # noqa
        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt",  # noqa
        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt",  # noqa
        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
        "vit_b_medical_imaging_decoder": "https://owncloud.gwdg.de/index.php/s/ahd3ZhZl2e0RIwz/download",
    }

    # Encoders and decoders are served from a single registry.
    return pooch.create(
        path=model_folder,
        base_url="",
        registry=dict(encoder_hashes, **decoder_hashes),
        urls=dict(encoder_locations, **decoder_locations),
    )
Return the segmentation models registry.
We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed the default device for your system is used.
    Else it will be checked if the device you have passed is supported.

    Args:
        device: The input device. If it is None or 'auto' the default device is selected.
            Device strings may contain a device index, e.g. 'cuda:0'.

    Returns:
        The device. A device that was passed and is supported is returned unchanged.

    Raises:
        RuntimeError: If the requested backend is not available or the device is not supported.
    """
    if device is None or device == "auto":
        return _get_default_device()

    # Strip an optional device index (e.g. 'cuda:1'), so that device strings are
    # validated in the same way as the corresponding torch.device objects.
    device_type = device.partition(":")[0] if isinstance(device, str) else device.type
    device_type = device_type.lower()

    if device_type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("PyTorch CUDA backend is not available.")
    elif device_type == "mps":
        if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
            raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
    elif device_type != "cpu":  # cpu is always available
        raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.")

    return device
Get the torch device.
If no device is passed the default device for your system is used. Else it will be checked if the device you have passed is supported.
Arguments:
- device: The input device. By default, the best device available on the system is selected.
Returns:
The device.
def get_sam_model(
    model_type: str = _DEFAULT_MODEL,
    device: Optional[Union[str, torch.device]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    return_sam: bool = False,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    progress_bar_factory: Optional[Callable] = None,
    decoder_path: Optional[Union[str, os.PathLike]] = None,
    **model_kwargs,
) -> SamPredictor:
    r"""Get the Segment Anything Predictor.

    This function will download the required model or load it from the cached weight file.
    This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
    The name of the requested model can be set via `model_type`.
    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
    for an overview of the available models.

    Alternatively this function can also load a model from weights stored in a local filepath.
    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
    a SAM model with vit_b encoder.

    By default the models are downloaded to a folder named 'micro_sam/models'
    inside your default cache directory, eg:
    * Mac: ~/Library/Caches/<AppName>
    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
    See the pooch.os_cache() documentation for more details:
    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

    Args:
        model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default.
            To get a list of all available model names you can call `micro_sam.util.get_model_names`.
        device: The device for the model. If 'None' is provided, will use GPU if available.
        checkpoint_path: The path to a file with weights that should be used instead of using the
            weights corresponding to `model_type`. If given, `model_type` must match the architecture
            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
            then `model_type` must be given as 'vit_b'.
        return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
        return_state: Return the unpickled checkpoint state. By default, set to 'False'.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
            If passed 'None', it does not initialize any parameter efficient finetuning.
            The dictionary that is passed is not modified.
        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
            By default, set to 'False'.
        progress_bar_factory: A function to create a progress bar for the model download.
        decoder_path: Optional path to weights for a segmentation decoder. If given and
            `return_state=True`, the decoder state is added to the returned state as
            'decoder_state'. This can be used to provide decoder-only weights that are
            separate from the encoder checkpoint.
        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.

    Returns:
        The Segment Anything predictor. If `return_sam` and / or `return_state` are set, a tuple
        is returned instead: (predictor, sam), (predictor, state) or (predictor, sam, state).

    Raises:
        ValueError: If `checkpoint_path` or `decoder_path` do not exist, if `model_type` is invalid
            or if `model_kwargs` or `peft_kwargs` are passed for a 'vit_t' model.
        RuntimeError: If a 'vit_t' model is requested but 'mobile_sam' is not installed.
    """
    device = get_device(device)

    # We support passing a local filepath to a checkpoint.
    # In this case we do not download any weights but just use the local weight file,
    # as it is, without copying it over anywhere or checking it's hashes.

    # checkpoint_path has not been passed, we download a known model and derive the correct
    # URL from the model_type. If the model_type is invalid pooch will raise an error.
    _provided_checkpoint_path = checkpoint_path is not None
    if checkpoint_path is None:
        checkpoint_path, model_hash, downloaded_decoder_path = _download_sam_model(model_type, progress_bar_factory)
        # An explicitly passed decoder takes precedence over the downloaded one.
        if decoder_path is None:
            decoder_path = downloaded_decoder_path

    # checkpoint_path has been passed, we use it instead of downloading a model.
    else:
        # Check if the file exists and raise an error otherwise.
        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
        # (If it isn't the model creation will fail below.)
        if not os.path.exists(checkpoint_path):
            raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.")
        model_hash = _compute_hash(checkpoint_path)

    if decoder_path is not None and not os.path.exists(decoder_path):
        raise ValueError(f"Decoder checkpoint at '{decoder_path}' could not be found.")

    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
    # before calling sam_model_registry.
    abbreviated_model_type = model_type[:5]
    if abbreviated_model_type not in _MODEL_TYPES:
        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
        raise RuntimeError(
            "'mobile_sam' is required for the vit-tiny. "
            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
        )

    state, model_state = _load_checkpoint(checkpoint_path)

    if _provided_checkpoint_path:
        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'
        # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination.
        from micro_sam.models.build_sam import _validate_model_type
        _provided_model_type = _validate_model_type(model_state)

        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'
        # Otherwise replace 'abbreviated_model_type' with the latter.
        if abbreviated_model_type != _provided_model_type:
            # Printing the message below to avoid any filtering of warnings on user's end.
            print(
                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
                f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. "
                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
                "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future."
            )

            # Replace the extracted 'abbreviated_model_type' subjected to the model weights.
            abbreviated_model_type = _provided_model_type

    # Whether to update parameters necessary to initialize the model
    if model_kwargs:  # Checks whether model_kwargs have been provided or not
        if abbreviated_model_type == "vit_t":
            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)

    else:
        sam = sam_model_registry[abbreviated_model_type]()

    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
    # Overwrites the SAM model by freezing the backbone and allow PEFT.
    if peft_kwargs and isinstance(peft_kwargs, dict):
        # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference.
        # This is done on a copy, so that the dictionary of the caller is left untouched.
        peft_kwargs = {key: value for key, value in peft_kwargs.items() if key != "quantize"}

        if abbreviated_model_type == "vit_t":
            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

    # In case the model checkpoints have some issues when it is initialized with different parameters than default.
    if flexible_load_checkpoint:
        sam = _handle_checkpoint_loading(sam, model_state)
    else:
        sam.load_state_dict(model_state)
    sam.to(device=device)

    # The predictor carries the metadata that identifies the loaded model.
    predictor = SamPredictor(sam)
    predictor.model_type = abbreviated_model_type
    predictor._hash = model_hash
    predictor.model_name = model_type
    predictor.checkpoint_path = checkpoint_path

    # Add the decoder to the state if we have one and if the state is returned.
    if decoder_path is not None and return_state:
        # NOTE: 'weights_only=False' unpickles arbitrary objects, so the decoder
        # checkpoint must come from a trusted source.
        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)

    if return_sam and return_state:
        return predictor, sam, state
    if return_sam:
        return predictor, sam
    if return_state:
        return predictor, state
    return predictor
Get the Segment Anything Predictor.
This function will download the required model or load it from the cached weight file.
The location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
The name of the requested model can be set via model_type.
See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
for an overview of the available models
Alternatively this function can also load a model from weights stored in a local filepath.
The corresponding file path is given via checkpoint_path. In this case model_type
must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
a SAM model with vit_b encoder.
By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, eg:
- Mac: ~/Library/Caches/
- Unix: ~/.cache/
or the value of the XDG_CACHE_HOME environment variable, if defined.
- Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache

See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
Arguments:
- model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default.
To get a list of all available model names you can call
micro_sam.util.get_model_names.
- device: The device for the model. If 'None' is provided, will use GPU if available.
- checkpoint_path: The path to a file with weights that should be used instead of using the
weights corresponding to
model_type. If given, model_type must match the architecture corresponding to the weight file, e.g. if you use weights for SAM with vit_b encoder then model_type must be given as 'vit_b'.
- return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
- return_state: Return the unpickled checkpoint state. By default, set to 'False'.
- peft_kwargs: Keyword arguments for the PEFT wrapper class. If passed 'None', it does not initialize any parameter efficient finetuning.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. By default, set to 'False'.
- progress_bar_factory: A function to create a progress bar for the model download.
- decoder_path: Optional path to weights for a segmentation decoder. If given and
return_state=True, the decoder state is added to the returned state as 'decoder_state'. This can be used to provide decoder-only weights that are separate from the encoder checkpoint.
- model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
Returns:
The Segment Anything predictor.
def export_custom_sam_model(
    checkpoint_path: Union[str, os.PathLike],
    model_type: str,
    save_path: Union[str, os.PathLike],
    with_segmentation_decoder: bool = False,
    prefix: str = "sam.",
) -> None:
    """Export a finetuned Segment Anything Model to the standard model format.

    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.

    Args:
        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
            This argument is not read by the export itself.
        save_path: Where to save the exported model.
        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
        prefix: The prefix to remove from the model parameter keys.

    Raises:
        RuntimeError: If the decoder state is requested but the checkpoint does not contain 'decoder_state'.
    """
    state, raw_model_state = _load_checkpoint(checkpoint_path=checkpoint_path)

    # Drop the leading prefix from every parameter name that carries it; other names are kept untouched.
    prefix_length = len(prefix)
    exported_model_state = OrderedDict()
    for name, parameter in raw_model_state.items():
        exported_name = name[prefix_length:] if name.startswith(prefix) else name
        exported_model_state[exported_name] = parameter

    # Without the decoder, the plain parameter dictionary is what gets serialized.
    if not with_segmentation_decoder:
        torch.save(exported_model_state, save_path)
        return

    if "decoder_state" not in state:
        raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")

    # With the decoder, both states are bundled in a single dictionary.
    torch.save({"model_state": exported_model_state, "decoder_state": state["decoder_state"]}, save_path)
Export a finetuned Segment Anything Model to the standard model format.
The exported model can be used by the interactive annotation tools in micro_sam.annotator.
Arguments:
- checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
- model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
- save_path: Where to save the exported model.
- with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
- prefix: The prefix to remove from the model parameter keys.
def export_custom_qlora_model(
    checkpoint_path: Optional[Union[str, os.PathLike]],
    finetuned_path: Union[str, os.PathLike],
    model_type: str,
    save_path: Union[str, os.PathLike],
) -> None:
    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.

    The exported state combines the LoRA parameters ('w_a_linear' / 'w_b_linear') found in the finetuned
    checkpoint with all parameters of the base model loaded via `get_sam_model`. The result replaces the
    'model_state' entry of the finetuned checkpoint, and every other entry of that checkpoint is kept.

    Args:
        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
        finetuned_path: The path to the new finetuned model, using QLoRA.
        model_type: The Segment Anything Model type corresponding to the checkpoint.
        save_path: Where to save the exported model.
    """
    # Step 1: Get the base SAM model: used to start finetuning from.
    _, sam = get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
    )

    # Step 2: Load the QLoRA-style finetuned model.
    ft_state, ft_model_state = _load_checkpoint(finetuned_path)

    # Step 3: Identify LoRA layers from QLoRA model.
    # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers.
    # - then copy the LoRA layers from the QLoRA model to the new state dict
    updated_model_state = {}

    # Block indices for which LoRA parameters were found, tracked separately for attention and MLP.
    modified_attn_layers = set()
    modified_mlp_layers = set()

    for k, v in ft_model_state.items():
        # The block index is the integer that directly follows 'blocks.' in the parameter name.
        # NOTE(review): 'layer_id' is only refreshed for keys that contain 'blocks.'; the checks below
        # presumably only ever match such keys - confirm against the checkpoint key layout.
        if "blocks." in k:
            layer_id = int(k.split("blocks.")[1].split(".")[0])
        if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1:
            modified_attn_layers.add(layer_id)
            updated_model_state[k] = v
        if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1:
            modified_mlp_layers.add(layer_id)
            updated_model_state[k] = v

    # Step 4: Next, we get all the remaining parameters from the base SAM model.
    for k, v in sam.state_dict().items():
        if "blocks." in k:
            layer_id = int(k.split("blocks.")[1].split(".")[0])
        if k.find("attn.qkv.") != -1:
            if layer_id in modified_attn_layers:  # We have LoRA in QKV layers, so we need to modify the key
                k = k.replace("qkv", "qkv.qkv_proj")
        elif k.find("mlp") != -1 and k.find("image_encoder") != -1:
            if layer_id in modified_mlp_layers:  # We have LoRA in MLP layers, so we need to modify the key
                k = k.replace("mlp.", "mlp.mlp_layer.")
        # Every base parameter is stored, under the renamed key where a rename applied.
        updated_model_state[k] = v

    # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff)
    ft_state['model_state'] = updated_model_state

    # Step 6: Store the new "state" to "save_path"
    torch.save(ft_state, save_path)
Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
The exported model can be used with the LoRA backbone by passing the relevant peft_kwargs to get_sam_model.
Arguments:
- checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
- finetuned_path: The path to the new finetuned model, using QLoRA.
- model_type: The Segment Anything Model type corresponding to the checkpoint.
- save_path: Where to save the exported model.
def precompute_image_embeddings(
    predictor: SamPredictor,
    input_: np.ndarray,
    save_path: Optional[Union[str, os.PathLike]] = None,
    lazy_loading: bool = False,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    batch_size: int = 1,
    mask: Optional[np.typing.ArrayLike] = None,
    pbar_init: Optional[Callable] = None,
    pbar_update: Optional[Callable] = None,
) -> ImageEmbeddings:
    """Compute the image embeddings (output of the encoder) for the input.

    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

    Args:
        predictor: The Segment Anything predictor.
        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
        save_path: Path to save the embeddings in a zarr container.
            By default, set to 'None', i.e. the computed embeddings will not be stored locally.
        lazy_loading: Whether to load all embeddings into memory or return an
            object to load them on demand when required. This only has an effect if 'save_path' is given
            and if the input is 3 dimensional. By default, set to 'False'.
        ndim: The dimensionality of the data. If not given will be deduced from the input data.
            By default, set to 'None', i.e. will be computed from the provided `input_`.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Whether to be verbose in the computation. By default, set to 'True'.
        batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
        mask: An optional mask to define areas that are ignored in the computation.
            The mask will be used within tiled embedding computation and tiles that don't contain any foreground
            in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            To enable using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The image embeddings.

    Raises:
        ValueError: If the (given or deduced) dimensionality is neither 2 nor 3.
    """
    ndim = input_.ndim if ndim is None else ndim

    # Reject unsupported dimensionalities up front, before a zarr container is created at 'save_path'
    # or a progress bar is opened. The message reports the dimensionality that is actually used.
    if ndim not in (2, 3):
        raise ValueError(f"Invalid dimensionality {ndim}, expect 2 or 3 dim data.")

    # Handle the embedding save_path.
    if save_path is None:
        # We don't have a save path, open in memory zarr file to hold tiled embeddings.
        f = zarr.group()
    else:
        # The existence check has to happen before opening, because mode "a" is used for both cases.
        embeddings_exist = os.path.exists(save_path)
        f = zarr.open(save_path, mode="a")
        if embeddings_exist:
            # Embeddings will be loaded from the existing file. Check that the saved embeddings
            # in there match the parameters of the function call.
            _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)

    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)

    # Dispatch on dimensionality and on whether tiling is requested.
    # The progress bar is closed even if the computation fails.
    try:
        if ndim == 2 and tile_shape is None:
            embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
        elif ndim == 2:
            embeddings = _compute_tiled_2d(
                input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
            )
        elif tile_shape is None:
            embeddings = _compute_3d(
                input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size
            )
        else:
            embeddings = _compute_tiled_3d(
                input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
            )
    finally:
        pbar_close()

    return embeddings
Compute the image embeddings (output of the encoder) for the input.
If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
Arguments:
- predictor: The Segment Anything predictor.
- input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
- save_path: Path to save the embeddings in a zarr container. By default, set to 'None', i.e. the computed embeddings will not be stored locally.
- lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional. By default, set to 'False'.
- ndim: The dimensionality of the data. If not given will be deduced from the input data.
By default, set to 'None', i.e. will be computed from the provided
input_.
- tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Whether to be verbose in the computation. By default, set to 'True'.
- batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
- mask: An optional mask to define areas that are ignored in the computation. The mask will be used within tiled embedding computation and tiles that don't contain any foreground in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
- pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. To enable using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
Returns:
The image embeddings.
def set_precomputed(
    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
) -> SamPredictor:
    """Set the precomputed image embeddings for a predictor.

    Args:
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        tile_id: Index for the tile. This is required if the embeddings are tiled.

    Returns:
        The predictor with set features.

    Raises:
        ValueError: If the features are 5-dimensional and no index is given,
            or if they are 4-dimensional and an index is given.
    """
    # Tiled embeddings: pick the entry of the requested tile, read its sizes from the
    # attributes stored with it and then continue as for non-tiled embeddings.
    if tile_id is not None:
        per_tile_features = image_embeddings["features"][str(tile_id)]
        return set_precomputed(
            predictor,
            {
                "features": per_tile_features,
                "input_size": per_tile_features.attrs["input_size"],
                "original_size": per_tile_features.attrs["original_size"],
            },
            i=i,
        )

    target_device = predictor.device
    features = image_embeddings["features"]
    assert features.ndim in (4, 5), f"{features.ndim}"

    # 5-dimensional features need an index, 4-dimensional features must not get one.
    index_given = i is not None
    if features.ndim == 5 and not index_given:
        raise ValueError("The data is 3D so an index i is needed.")
    if features.ndim == 4 and index_given:
        raise ValueError("The data is 2D so an index is not needed.")

    # Tensors are used directly; anything else is read and converted via numpy.
    if torch.is_tensor(features):
        selected = features[i] if index_given else features
    else:
        selected = torch.from_numpy(features[i] if index_given else features[:])
    predictor.features = selected.to(target_device)

    predictor.original_size = image_embeddings["original_size"]
    predictor.input_size = image_embeddings["input_size"]
    predictor.is_image_set = True

    return predictor
Set the precomputed image embeddings for a predictor.
Arguments:
- predictor: The Segment Anything predictor.
- image_embeddings: The precomputed image embeddings computed by
precompute_image_embeddings.
- i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:
The predictor with set features.
def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """Compute the intersection over union of two masks.

    Only entries that are equal to 1 are treated as foreground. A small constant is added
    to the denominator, so two masks without foreground yield 0.0.

    Args:
        mask1: The first mask.
        mask2: The second mask.

    Returns:
        The intersection over union of the two masks.
    """
    foreground1 = mask1 == 1
    foreground2 = mask2 == 1

    n_intersection = np.count_nonzero(np.logical_and(foreground1, foreground2))
    n_union = np.count_nonzero(np.logical_or(foreground1, foreground2))

    # The epsilon avoids a division by zero for an empty union.
    eps = 1e-7
    return float(n_intersection) / (float(n_union) + eps)
Compute the intersection over union of two masks.
Arguments:
- mask1: The first mask.
- mask2: The second mask.
Returns:
The intersection over union of the two masks.
def get_centers_and_bounding_boxes(
    segmentation: np.ndarray, mode: str = "v"
) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
    """Returns the center coordinates of the foreground instances in the ground-truth.

    Args:
        segmentation: The segmentation.
        mode: Determines the functionality used for computing the centers.
            If 'v', the point of maximal distance to the object boundary is used as center.
            This center is guaranteed to lie inside the object, also for concave shapes.
            If 'p' the object's centroids computed by skimage are used.

    Returns:
        A dictionary that maps object ids to the corresponding centroid.
        A dictionary that maps object_ids to the corresponding bounding box.
    """
    assert mode in ["p", "v"], "Choose either 'p' for regionprops centroids or 'v' for distance-based centers"

    props = regionprops(segmentation)

    if mode == "p":
        centers = {prop.label: prop.centroid for prop in props}
    elif mode == "v":
        # The center is the object pixel with the largest distance to the nearest boundary pixel.
        # A single distance transform is computed for the whole segmentation and then evaluated
        # per object, instead of running one distance transform per object.
        n_axes = segmentation.ndim

        # Padding by one pixel gives objects that touch the image border a boundary there as well.
        padded_segmentation = np.pad(segmentation, 1)
        boundary_mask = find_boundaries(padded_segmentation, mode="inner")
        boundary_distances = distance_transform(boundary_mask == 0)

        centers = {}
        for prop in props:
            box = prop.bbox
            box_start, box_stop = box[:n_axes], box[n_axes:]

            # Crop the distances to the bounding box; the +1 accounts for the padding.
            crop = tuple(slice(lo + 1, hi + 1) for lo, hi in zip(box_start, box_stop))
            cropped_distances = boundary_distances[crop]

            # Pixels of other objects or background inside the box must never win the argmax.
            object_distances = np.where(prop.image, cropped_distances, -1.0)
            flat_position = int(np.argmax(object_distances))
            local_center = np.unravel_index(flat_position, object_distances.shape)

            # Shift from bounding box coordinates back to image coordinates.
            centers[prop.label] = tuple(int(pos + offset) for pos, offset in zip(local_center, box_start))

    bounding_boxes = {prop.label: prop.bbox for prop in props}

    assert len(bounding_boxes) == len(centers), f"{len(bounding_boxes)}, {len(centers)}"
    return centers, bounding_boxes
Returns the center coordinates of the foreground instances in the ground-truth.
Arguments:
- segmentation: The segmentation.
- mode: Determines the functionality used for computing the centers. If 'v', the point of maximal distance to the object boundary is used as center. This center is guaranteed to lie inside the object, also for concave shapes. If 'p' the object's centroids computed by skimage are used.
Returns:
A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.
def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
    """Helper function to load image data from file.

    Args:
        path: The filepath to the image data.
        key: The internal filepath for complex data formats like hdf5.
            If it is not given, the file is read with imageio.
        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
            This only has an effect if 'key' is given.

    Returns:
        The image data.
    """
    # Without a key the file is a regular image and is read directly.
    if key is None:
        return imageio.imread(path)

    # With a key the data lives in a container file. It is either returned as
    # the dataset object (lazy) or loaded fully into memory.
    with open_file(path, mode="r") as container:
        dataset = container[key]
        return dataset if lazy_loading else dataset[:]
Helper function to load image data from file.
Arguments:
- path: The filepath to the image data.
- key: The internal filepath for complex data formats like hdf5.
- lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
Returns:
The image data.
def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
    """Convert the segmentation to one-hot encoded masks.

    Args:
        segmentation: The segmentation. Any integer dtype is accepted.
        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
            By default, computes the number of ids from the provided `segmentation` masks.

    Returns:
        The one-hot encoded masks, with shape NUM_OBJECTS x 1 x H x W. The background (id 0) has no channel.

    Raises:
        RuntimeError: If `segmentation_ids` is given but empty, or if it contains the background id 0.
    """
    masks = segmentation.copy()
    if segmentation_ids is None:
        n_ids = int(segmentation.max())

    else:
        msg = "No foreground objects were found."
        if len(segmentation_ids) == 0:  # The list should not be completely empty.
            raise RuntimeError(msg)

        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
            raise RuntimeError(msg)

        # the segmentation ids have to be sorted
        segmentation_ids = np.sort(segmentation_ids)

        # set the non selected objects to zero and relabel sequentially
        masks[~np.isin(masks, segmentation_ids)] = 0
        masks = relabel_sequential(masks)[0]
        n_ids = len(segmentation_ids)

    # 'scatter_' only accepts int64 indices, so label images of other integer types
    # (e.g. int32 or uint32) are converted here. int64 input is passed through without a copy.
    masks = torch.from_numpy(masks.astype(np.int64, copy=False))

    # One channel per id, including a channel for the background that is dropped again below.
    one_hot_shape = (n_ids + 1,) + masks.shape
    masks = masks.unsqueeze(0)  # add dimension to scatter
    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]

    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
    masks = masks.unsqueeze(1)
    return masks
Convert the segmentation to one-hot encoded masks.
Arguments:
- segmentation: The segmentation.
- segmentation_ids: Optional subset of ids that will be used to subsample the masks.
By default, computes the number of ids from the provided
segmentation masks.
Returns:
The one-hot encoded masks.
def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
    """Get a suitable block shape for chunking a given shape.

    The primary use for this is determining chunk sizes for
    zarr arrays or block shapes for parallelization.

    Args:
        shape: The image or volume shape.

    Returns:
        The block shape. It never exceeds the given shape along any axis.

    Raises:
        ValueError: If the shape is neither 2 nor 3 dimensional.
    """
    # The largest block that is used per dimensionality.
    max_block_shapes = {2: (1024, 1024), 3: (32, 256, 256)}

    ndim = len(shape)
    if ndim not in max_block_shapes:
        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")

    # Clip the block to the data extent along each axis.
    return tuple(min(max_extent, extent) for max_extent, extent in zip(max_block_shapes[ndim], shape))
Get a suitable block shape for chunking a given shape.
The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.
Arguments:
- shape: The image or volume shape.
Returns:
The block shape.
def _model_display_labels(model_name: str, model_url: str) -> List[str]:
    """Derive the labels shown in the model overview for one registry entry from its download URL."""
    labels = []

    if "https://dl.fbaipublicfiles.com/segment_anything/" in model_url:  # Valid v1 SAM models.
        labels.append(model_name)

    if "https://owncloud.gwdg.de/" in model_url:  # Our own hosted models (in their v1 mode quite often)
        # The MobileSAM model is listed without a version.
        labels.append(model_name if model_name == "vit_t" else f"{model_name} (v1)")

    # Now for our models, the BioImageIO ModelZoo upload structure is such that:
    # '/1/files' corresponds to v2 models.
    # '/1.1/files' corresponds to v3 models.
    # '/1.2/files' corresponds to v4 models.
    for url_part, version in (("/1/files", "v2"), ("/1.1/files", "v3"), ("/1.2/files", "v4")):
        if url_part in model_url:
            labels.append(f"{model_name} ({version})")

    return labels


def micro_sam_info() -> None:
    """Display μSAM information using a rich console.

    Prints the version, documentation and publication links, the cache directory (which is created
    if it does not exist), the supported models, system information and the available compute device.

    The command line option '--download' can be used to download models:
    '--download models' downloads all listed models and '--download models NAME [NAME ...]' downloads
    the given models. Models are selected by their plain name, e.g. 'vit_b_lm', without the version
    suffix that is shown in the model overview.
    """
    import psutil
    import platform
    import argparse
    from rich import progress
    from rich.panel import Panel
    from rich.table import Table
    from rich.console import Console

    import torch
    import micro_sam

    parser = argparse.ArgumentParser(description="μSAM Information Booth")
    parser.add_argument(
        "--download", nargs="+", metavar=("WHAT", "KIND"),
        help="Downloads the pretrained SAM models. "
        "'--download models' -> downloads all pretrained models; "
        "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; "
        "'--download model/models vit_b_lm' -> downloads a single specified model."
    )
    args = parser.parse_args()

    # Open up a new console.
    console = Console()

    # The header for information CLI.
    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
    console.print("-" * console.width)

    # μSAM version panel.
    console.print(
        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
    )

    # The documentation link panel.
    console.print(
        Panel(
            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
        )
    )

    # The publication panel.
    console.print(
        Panel(
            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
        )
    )

    # Creating a cache directory when users' run `micro_sam.info`.
    cache_dir = get_cache_directory()
    os.makedirs(cache_dir, exist_ok=True)

    # The cache directory panel.
    console.print(
        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{cache_dir}", title="Cache Directory")
    )

    # Collect the models. Two lists are kept on purpose:
    # - 'available_models' holds the labels for display, which can carry a version suffix like ' (v4)'.
    # - 'downloadable_models' holds the plain registry names, which identify a model for the download.
    # NOTE(review): this assumes that '_download_sam_model' expects the plain registry names - confirm.
    available_models = []
    downloadable_models = []
    for model_name, model_path in models().urls.items():
        if model_name.endswith("decoder"):  # We filter out the decoder models.
            continue

        labels = _model_display_labels(model_name, model_path)
        if labels:
            available_models.extend(labels)
            downloadable_models.append(model_name)

    model_list = "\n".join(available_models)

    # The available models panel.
    console.print(
        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
    )

    # The system information table.
    total_memory = psutil.virtual_memory().total / (1024 ** 3)
    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
    table.add_column("Property")
    table.add_column("Value", style="bold #56B4E9")
    table.add_row("System", platform.system())
    table.add_row("Node Name", platform.node())
    table.add_row("Release", platform.release())
    table.add_row("Version", platform.version())
    table.add_row("Machine", platform.machine())
    table.add_row("Processor", platform.processor())
    table.add_row("Platform", platform.platform())
    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
    console.print(table)

    # The device information and check for available GPU acceleration.
    default_device = _get_default_device()

    if default_device == "cuda":
        device_index = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_index)
        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
    elif default_device == "mps":
        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
    else:
        console.print(
            Panel(
                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
                title="Device Information"
            )
        )

    # The section allowing to download models.
    # NOTE: In future, can be extended to download sample data.
    if not args.download:
        return

    mode, *model_types = [t.lower() for t in args.download]

    if mode not in {"models", "model"}:
        console.print(f"[red]Unknown option for --download: {mode}[/]")
        return

    if not model_types:  # If user did not specify, we will download all models.
        download_list = downloadable_models
    else:
        download_list = model_types
        incorrect_models = [m for m in download_list if m not in downloadable_models]
        if incorrect_models:
            console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error"))
            return

    with progress.Progress(
        progress.SpinnerColumn(),
        progress.TextColumn("[progress.description]{task.description}"),
        progress.BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.0f}%",
        progress.TimeRemainingColumn(),
        console=console,
    ) as prog:
        task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list))
        for model_type in download_list:
            prog.update(task, description=f"Downloading [cyan]{model_type}[/]…")
            _download_sam_model(model_type=model_type)
            prog.advance(task)

    console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))
Display μSAM information using a rich console.
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    shape: Optional[Tuple[int, int]] = None,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
    with_background: bool = False,
    merge_exclusively: bool = True,
) -> np.ndarray:
    """Convert the output of the automatic mask generation to an instance segmentation.

    Args:
        masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`,
            or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask.
        shape: The shape of the output segmentation. If None, it will be derived from the mask input.
            If the masks were predicted with tiling then the shape must be given.
        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.
            By default, set to 'True'.
        with_background: Whether to remove the largest object, which often covers the background for AMG.
        merge_exclusively: Whether to exclude previous merged masks from merging.

    Returns:
        The instance segmentation.

    Raises:
        ValueError: If no masks are given and the shape is not given either,
            because the output shape cannot be derived in this case.
    """
    # Masks are written from the largest to the smallest area.
    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    if shape is None:
        if not masks:
            raise ValueError("The 'shape' must be given if no masks are passed, it cannot be derived from the input.")
        shape = masks[0]["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    def require_numpy(mask):
        return mask.cpu().numpy() if torch.is_tensor(mask) else mask

    seg_id = 1
    for mask_data in masks:
        area = mask_data["area"]
        if (area < min_object_size) or (max_object_size is not None and area > max_object_size):
            continue

        this_mask = require_numpy(mask_data["segmentation"])
        # A 'seg_id' stored with the mask takes precedence over the running id.
        this_seg_id = mask_data.get("seg_id", seg_id)
        if "global_bbox" in mask_data:
            # Bounding boxes are read as (x, y, width, height) and converted to (row, column) slices.
            # 'bbox' selects the object within the mask, 'global_bbox' the target region in the output.
            bb = mask_data["bbox"]
            bb = np.s_[bb[1]:bb[1] + bb[3], bb[0]:bb[0] + bb[2]]
            global_bb = mask_data["global_bbox"]
            global_bb = np.s_[global_bb[1]:global_bb[1] + global_bb[3], global_bb[0]:global_bb[0] + global_bb[2]]
            if merge_exclusively:
                # Only write to pixels that are not yet covered by a previous mask.
                this_mask = np.logical_and(this_mask[bb], segmentation[global_bb] == 0)
            else:
                this_mask = this_mask[bb]
            # Slicing yields a view, so this writes into 'segmentation'.
            segmentation[global_bb][this_mask] = this_seg_id
        else:
            if merge_exclusively:
                this_mask = np.logical_and(this_mask, segmentation == 0)
            segmentation[this_mask] = this_seg_id
        seg_id = this_seg_id + 1

    block_shape = (512, 512)
    if label_masks:
        segmentation_cc = np.zeros_like(segmentation, dtype=segmentation.dtype)
        segmentation = parallel_impl.label(segmentation, out=segmentation_cc, block_shape=block_shape)

    # Remove objects that are too small, and the largest object if requested.
    seg_ids, sizes = parallel_impl.unique(segmentation, return_counts=True, block_shape=block_shape)
    filter_ids = seg_ids[sizes < min_object_size]
    if with_background:
        bg_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [bg_id]])

    filter_mask = np.zeros(segmentation.shape, dtype="bool")
    filter_mask = parallel_impl.isin(segmentation, filter_ids, out=filter_mask, block_shape=block_shape)
    segmentation[filter_mask] = 0

    # Use the returned array instead of relying on an in-place side effect.
    # NOTE(review): the first return value is presumably the relabeled array - confirm against elf.parallel.
    segmentation = parallel_impl.relabel_consecutive(segmentation, block_shape=block_shape)[0]

    return segmentation
Convert the output of the automatic mask generation to an instance segmentation.
Arguments:
- masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`, or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask.
- shape: The shape of the output segmentation. If None, it will be derived from the mask input. If the masks were predicted with tiling then the shape must be given.
- min_object_size: The minimal size of an object in pixels. By default, set to '0'.
- max_object_size: The maximal size of an object in pixels.
- label_masks: Whether to apply connected components to the result before removing small objects. By default, set to 'True'.
- with_background: Whether to remove the largest object, which often covers the background for AMG.
- merge_exclusively: Whether to exclude pixels that are already covered by previously merged masks when merging a mask.
Returns:
The instance segmentation.
def apply_nms(
    predictions: List[Dict[str, Any]],
    min_size: int,
    shape: Optional[Tuple[int, int]] = None,
    perform_box_nms: bool = False,
    nms_thresh: float = 0.9,
    max_size: Optional[int] = None,
    intersection_over_min: bool = False,
) -> np.ndarray:
    """Apply non-maximum suppression to mask predictions from a segment anything model.

    The predictions are first filtered by size, then non-maximum suppression is applied
    (on the boxes or on the masks), and the remaining masks are merged into a segmentation.

    Args:
        predictions: The mask predictions from SAM. Each entry must provide the keys 'segmentation',
            'predicted_iou', 'stability_score' and 'bbox'. If the first entry contains 'global_bbox'
            the predictions are treated as the result of a tiled prediction.
        min_size: The minimum mask size to keep in the output.
            If it is larger than 0, only masks with an area larger than this value are kept.
        shape: The shape of the output segmentation.
            For tiled predictions this is inferred from the tile-local mask shapes if it is not passed.
        perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
        nms_thresh: The threshold for filtering out objects in NMS.
        max_size: The maximum mask size to keep in the output.
            If given, only masks with an area smaller than this value are kept.
        intersection_over_min: Whether to perform intersection over the minimum overlap shape
            or to perform intersection over union.

    Returns:
        The segmentation obtained from merging the masks left after NMS.
    """
    # Check if the input comes with a 'global_bbox' attribute. If it does, then the predictions are from
    # a tiled prediction. In this case, we have to take the coordinates w.r.t. the tiling into account.
    is_tiled = "global_bbox" in predictions[0]
    if is_tiled and shape is None:
        shape = _infer_tiled_shape(predictions)
    if shape is None:
        shape = predictions[0]["segmentation"].shape

    masks = [pred["segmentation"] for pred in predictions]
    # The stacked masks are only needed for the non-tiled mask NMS.
    nms_masks = None if is_tiled else torch.cat([mask[None] for mask in masks], dim=0)
    data = amg_utils.MaskData(masks=masks, iou_preds=torch.tensor([pred["predicted_iou"] for pred in predictions]))
    data["boxes"] = torch.tensor(np.array([pred["bbox"] for pred in predictions]))
    data["area"] = [int(mask.sum()) for mask in data["masks"]]
    data["stability_scores"] = torch.tensor([pred["stability_score"] for pred in predictions])
    if is_tiled:
        data["global_boxes"] = torch.tensor(np.array([pred["global_bbox"] for pred in predictions]))

    if min_size > 0:
        keep_by_size = torch.tensor(
            [i for i, area in enumerate(data["area"]) if area > min_size], dtype=torch.long,
        )
        data.filter(keep_by_size)
        if nms_masks is not None:
            nms_masks = nms_masks[keep_by_size]

    if max_size is not None:
        # The dtype must be given explicitly: an empty list would otherwise yield a float tensor,
        # which cannot be used for indexing when all masks are filtered out.
        keep_by_size = torch.tensor(
            [i for i, area in enumerate(data["area"]) if area < max_size], dtype=torch.long,
        )
        data.filter(keep_by_size)
        if nms_masks is not None:
            nms_masks = nms_masks[keep_by_size]

    # All masks have been removed by the size filtering.
    if len(data["masks"]) == 0:
        return np.zeros(shape, dtype="uint32")

    scores = data["iou_preds"] * data["stability_scores"]
    boxes = _xywh_to_xyxy(data["global_boxes"] if is_tiled else data["boxes"])
    if perform_box_nms:
        assert not intersection_over_min  # not implemented
        keep_by_nms = batched_nms(
            boxes,
            scores,
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=nms_thresh,
        )
    elif is_tiled:
        keep_by_nms = _batched_tiled_mask_nms(
            masks=data["masks"],
            boxes=data["boxes"],
            global_boxes=data["global_boxes"],
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    else:
        keep_by_nms = _batched_mask_nms(
            masks=nms_masks,
            boxes=boxes,
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    data.filter(keep_by_nms)

    if is_tiled:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box, "global_bbox": global_box}
            for mask, area, box, global_box in zip(data["masks"], data["area"], data["boxes"], data["global_boxes"])
        ]
    else:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box}
            for mask, area, box in zip(data["masks"], data["area"], data["boxes"])
        ]

    if mask_data:
        segmentation = mask_data_to_segmentation(mask_data, shape=shape, min_object_size=min_size)
    else:  # In case all objects have been filtered out due to size filtering.
        segmentation = np.zeros(shape, dtype="uint32")

    return segmentation
Apply non-maximum suppression to mask predictions from a segment anything model.
Arguments:
- predictions: The mask predictions from SAM.
- min_size: The minimum mask size to keep in the output.
- shape: The shape of the output segmentation. For tiled predictions this is inferred from the tile-local mask shapes if it is not passed.
- perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
- nms_thresh: The threshold for filtering out objects in NMS.
- max_size: The maximum mask size to keep in the output.
- intersection_over_min: Whether to perform intersection over the minimum overlap shape or to perform intersection over union.
Returns:
The segmentation obtained from merging the masks left after NMS.