micro_sam.training.training
import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio

import torch
from torch.optim import Optimizer
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler

import torch_em
from torch_em.util import load_data
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from . import joint_sam_trainer as joint_trainers
from ..util import get_device, get_model_names, export_custom_sam_model
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform


FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is in the range [0, 255].
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channels for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}."
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}."
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        counter += 1
        # Stop after the requested number of batches has been verified.
        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
            break


# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield


def _count_parameters(model_parameters):
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")


def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    overwrite_training: bool = True,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None, all objects will be used;
            if given, objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will override n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for the napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
        overwrite_training: Whether to overwrite the trained model stored at the same location.
            By default, overwrites the trained model at each run.
            If set to 'False', it will avoid retraining the model if the previous run was completed.
        model_kwargs: Additional keyword arguments for `util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

        # Create the UNETR decoder (if we train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            joint_model_params = [params for params in model.parameters()]  # sam parameters
            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            model_params = joint_model_params
        else:
            model_params = model.parameters()

        optimizer = optimizer_class(model_params, lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        # Avoid overwriting a trained model, if desired by the user.
        trainer_fit_params["overwrite_training"] = overwrite_training

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)  # Remainder of full hours, so the HH:MM:SS format is correct.
        seconds = int(round(t_run % 60, 0))
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")


def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
        assert isinstance(path, (str, os.PathLike))

    # Check the underlying data dimensionality.
    if raw_key is None:  # If no key is given then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file from key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

    if not isinstance(patch_shape, tuple):
        patch_shape = tuple(patch_shape)

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        return (1,) + patch_shape
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + patch_shape
    else:
        return patch_shape


def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with an additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, this is determined from the inputs.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
    is_seg_dataset = kwargs.pop("is_seg_dataset", None)

    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
    # Get valid raw paths to make checks possible.
    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath of one image.
        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
    else:  # Otherwise 'raw_key' is either None or an internal path of a container format supported by 'elf'.
        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]

    # Load one of the raw inputs to validate whether it is RGB or not.
    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
    if test_raw_inputs.ndim == 3:
        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image with channels last.
            is_seg_dataset = False  # We use 'ImageCollectionDataset' in this case.
            # We need to provide a list of inputs to 'ImageCollectionDataset'.
            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths

            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
            with_channels = False if with_channels is None else with_channels

        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is an RGB image with channels first.
            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
            with_channels = True if with_channels is None else with_channels

    # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'.
    # The user's choice takes priority; otherwise we set this to our suggested default.
    with_channels = False if with_channels is None else with_channels

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=True,
            min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            is_seg_dataset=is_seg_dataset,
            raw_transform=raw_transform,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset


def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
    # which the users can provide to get their desired segmentation dataset.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

    ds = default_sam_dataset(**ds_kwargs)
    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)


CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources."""


def _find_best_configuration():
    if torch.cuda.is_available():

        # Check how much memory we have and select the best matching GPU
        # for the available VRAM size.
        _, vram = torch.cuda.mem_get_info()
        vram = vram / 1e9  # in GB

        # Maybe we can get more configurations in the future.
        if vram > 80:  # More than 80 GB: use the A100 configurations.
            return "A100"
        elif vram > 30:  # More than 30 GB: use the V100 configurations.
            return "V100"
        elif vram > 14:  # More than 14 GB: use the RTX5000 configurations.
            return "rtx5000"
        else:  # Otherwise: not enough memory to train on the GPU, use the CPU instead.
            return "CPU"
    else:
        return "CPU"


def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        model_type: Override the default model type.
            This can be used to use one of the micro_sam models as a starting point
            instead of a default SAM model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
    """
    if configuration in CONFIGURATIONS:
        train_kwargs = dict(CONFIGURATIONS[configuration])  # Copy, so that popping below does not mutate CONFIGURATIONS.
    else:
        raise ValueError(f"Invalid configuration {configuration}, expect one of {list(CONFIGURATIONS.keys())}.")

    if model_type is None:
        model_type = train_kwargs.pop("model_type")
    else:
        expected_model_type = train_kwargs.pop("model_type")
        if model_type[:5] != expected_model_type:
            warnings.warn("You have specified a different model type.")

    train_kwargs.update(**kwargs)
    train_sam(
        name=name,
        train_loader=train_loader,
        val_loader=val_loader,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type,
        **train_kwargs
    )


def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):

    # Whether the model is stored in the current working directory or in another location.
    if save_root is None:
        save_root = os.getcwd()  # Map this to the current working directory, if not specified by the user.

    # Get the 'best' model checkpoint ready for export.
    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
    if not os.path.exists(best_checkpoint):
        raise FileNotFoundError(f"The trained model was not found at the expected location: '{best_checkpoint}'.")

    # Export the model if an output path has been given.
    if output_path:

        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
            export_custom_sam_model(
                checkpoint_path=best_checkpoint,
                model_type=model_type[:5],
                save_path=output_path,
                with_segmentation_decoder=with_segmentation_decoder,
            )

        # Otherwise we export it as a bioimage.io model.
        else:
            from micro_sam.bioimageio import export_sam_model

            # Load an image and the corresponding labels from the val loader.
            with torch.no_grad():
                image_data, label_data = next(iter(val_loader))
                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()

                # Select the first channel of the label image if we have a channel axis, i.e. it contains the labels.
                if label_data.ndim == 3:
                    label_data = label_data[0]  # Gets the channel with the instances.
                assert image_data.shape == label_data.shape
                label_data = label_data.astype("uint32")

                export_sam_model(
                    image=image_data,
                    label_image=label_data,
                    model_type=model_type[:5],
                    name=checkpoint_name,
                    output_path=output_path,
                    checkpoint_path=best_checkpoint,
                )

        # The final path where the model has been stored.
        final_path = output_path

    else:  # If no exports have been made, inform the user about the best checkpoint.
        final_path = best_checkpoint

    return final_path


def main():
    """@private"""
    import argparse

    # Keep the lists for validation below and joined strings for the help messages.
    available_models = list(get_model_names())
    available_models_str = ", ".join(available_models)

    available_configurations = list(CONFIGURATIONS.keys())
    available_configurations_str = ", ".join(available_configurations)

    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")

    # Images and labels for training.
    parser.add_argument(
        "--images", required=True, type=str, nargs="*",
        help="Filepath to images or the directory where the image data is stored."
    )
    parser.add_argument(
        "--labels", required=True, type=str, nargs="*",
        help="Filepath to ground-truth labels or the directory where the label data is stored."
    )
    parser.add_argument(
        "--image_key", type=str, default=None,
        help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--label_key", type=str, default=None,
        help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file."
    )

    # Images and labels for validation.
    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
    # Users can choose to have their explicit validation set via this feature as well.
    parser.add_argument(
        "--val_images", type=str, nargs="*",
        help="Filepath to images for validation or the directory where the image data is stored."
    )
    parser.add_argument(
        "--val_labels", type=str, nargs="*",
        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
    )
    parser.add_argument(
        "--val_image_key", type=str, default=None,
        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--val_label_key", type=str, default=None,
        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
    )

    # Other necessary settings for training.
    parser.add_argument(
        "--configuration", type=str, default=_find_best_configuration(),
        help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations_str}."
    )

    def none_or_str(value):
        if value.lower() == 'none':
            return None
        return value

    parser.add_argument(
        "--segmentation_decoder", type=none_or_str, default="instances",
        # TODO: in the future, we can extend this to semantic segmentation or even more advanced tasks.
        help="Whether to finetune the Segment Anything Model with an additional segmentation decoder for the desired "
             "targets. By default, it uses the 'instances' option, i.e. trains with the additional segmentation "
             "decoder for instance segmentation. Pass 'None' to train without the additional segmentation decoder."
    )

    # Optional advanced settings a user can opt to change the values for.
    parser.add_argument(
        "-d", "--device", type=str, default=None,
        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
             "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs="*", default=(512, 512),
        help="The choice of patch shape for training the Segment Anything Model. "
             "By default, a patch size of 512x512 is used."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=None,
        help=f"The Segment Anything Model that will be used for finetuning, one of {available_models_str}."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded for finetuning."
    )
    parser.add_argument(
        "-s", "--save_root", type=str, default=None,
        help="The directory where the trained models and corresponding logs will be stored. "
             "By default, they are stored in your current working directory."
    )
    parser.add_argument(
        "--trained_model_name", type=str, default="sam_model",
        help="The custom name of the trained model sub-folder. Allows users to have several trained models "
             "under the same 'save_root'."
    )
    parser.add_argument(
        "--output_path", type=str, default=None,
        help="The directory (e.g. '/path/to/folder') or filepath (e.g. '/path/to/model.pt') to export the trained model."
    )
    parser.add_argument(
        "--n_epochs", type=int, default=100,
        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
    )
    parser.add_argument(
        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1."
    )
    parser.add_argument(
        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images. "
             "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
    )

    args = parser.parse_args()

    # 1. Get all necessary parameters for training.
    checkpoint_name = args.trained_model_name
    config = args.configuration
    model_type = args.model_type
    checkpoint_path = args.checkpoint_path
    batch_size = args.batch_size
    patch_shape = args.patch_shape
    epochs = args.n_epochs
    num_workers = args.num_workers
    device = args.device
    save_root = args.save_root
    output_path = args.output_path

    if args.segmentation_decoder and args.segmentation_decoder != "instances":
        raise ValueError(
            "The 'segmentation_decoder' argument currently supports 'instances' as input argument only."
        )
    with_segmentation_decoder = (args.segmentation_decoder is not None)

    # Get image paths and corresponding keys.
    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa

    # 2. Prepare the dataloaders.

    # If the user wants to preprocess the inputs, we allow the possibility to do so.
    _raw_transform = get_raw_transform(args.preprocess)

    # Get the dataset with files for training.
    dataset = default_sam_dataset(
        raw_paths=train_images,
        raw_key=train_image_key,
        label_paths=train_gt,
        label_key=train_gt_key,
        patch_shape=patch_shape,
        with_segmentation_decoder=with_segmentation_decoder,
        raw_transform=_raw_transform,
    )

    # If no validation images are provided explicitly, we create a val split from the training data.
    if val_images is None:
        assert val_gt is None and val_image_key is None and val_gt_key is None
        # Use 10% of the dataset - at least one image - for validation.
        n_val = max(1, int(0.1 * len(dataset)))
        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])

    else:  # If val images are provided, we create a new dataset for them.
        train_dataset = dataset
        val_dataset = default_sam_dataset(
            raw_paths=val_images,
            raw_key=val_image_key,
            label_paths=val_gt,
            label_key=val_gt_key,
            patch_shape=patch_shape,
            with_segmentation_decoder=with_segmentation_decoder,
            raw_transform=_raw_transform,
        )

    # Get the dataloaders from the datasets.
    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    # 3. Train the Segment Anything Model.

    # Get a valid model and other necessary parameters for training.
    if model_type is not None and model_type not in available_models:
        raise ValueError(f"'{model_type}' is not a valid choice of model.")
    if config is not None and config not in available_configurations:
        raise ValueError(f"'{config}' is not a valid choice of configuration.")

    if model_type is None:  # If the user does not specify the model, we use the default model corresponding to the config.
        model_type = CONFIGURATIONS[config]["model_type"]

    train_sam_for_configuration(
        name=checkpoint_name,
        configuration=config,
        model_type=model_type,
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=epochs,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        freeze=None,  # TODO: Allow for PEFT.
        device=device,
        save_root=save_root,
        peft_kwargs=None,  # TODO: Allow for PEFT.
    )

    # 4. Export the model, if desired by the user.
    final_path = _export_helper(
        save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader
    )

    print(f"Training has finished. The trained model is saved at {final_path}.")
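For orientation, the following sketch mirrors what `main` does programmatically: build a dataset with `default_sam_dataset`, split off a validation set, wrap both splits in dataloaders, and train with `train_sam_for_configuration`. The file paths are placeholders, and it is assumed that these helpers are re-exported from `micro_sam.training`.

import os

from torch.utils.data import random_split

import torch_em
from micro_sam.training import default_sam_dataset, train_sam_for_configuration

# Hypothetical layout: folders of 2D tif images and matching instance label masks.
data_dir = "/path/to/data"  # placeholder

# One dataset over all images; "*.tif" is used as a glob key to select files in the folders.
dataset = default_sam_dataset(
    raw_paths=os.path.join(data_dir, "images"), raw_key="*.tif",
    label_paths=os.path.join(data_dir, "labels"), label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
)

# Hold out 10% of the samples (at least one) for validation, as the CLI in 'main' does.
n_val = max(1, int(0.1 * len(dataset)))
train_ds, val_ds = random_split(dataset, lengths=[len(dataset) - n_val, n_val])

train_loader = torch_em.get_data_loader(train_ds, batch_size=1, shuffle=True, num_workers=2)
val_loader = torch_em.get_data_loader(val_ds, batch_size=1, shuffle=True, num_workers=2)

# Train with settings matched to the hardware; extra kwargs are forwarded to train_sam.
train_sam_for_configuration(
    name="sam_model", configuration="CPU",
    train_loader=train_loader, val_loader=val_loader,
    with_segmentation_decoder=True, n_epochs=25,
)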
FilePath = typing.Union[str, os.PathLike]
def train_sam(
    name: str,
    model_type: str,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-05,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[torch.optim.Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    overwrite_training: bool = True,
    **model_kwargs,
) -> None:
Run training for a SAM model.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement.
- n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- device: The device to use for training.
- lr: The learning rate.
- n_sub_iteration: The number of iterative prompts per training iteration.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- mask_prob: The probability for using a mask as input in a given training sub-iteration.
- n_iterations: The number of iterations to use for training. This will override n_epochs if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- ignore_warnings: Whether to ignore raised warnings.
- verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
- box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
- overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
- model_kwargs: Additional keyword arguments for `util.get_sam_model`.
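A minimal call might look as follows. This is a sketch with placeholder paths and names; the loaders are built with `default_sam_loader`, but any loader that passes the input checks described above would work.

from micro_sam.training import default_sam_loader, train_sam

train_loader = default_sam_loader(
    raw_paths="/path/to/train/images", raw_key="*.tif",
    label_paths="/path/to/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True,
)
val_loader = default_sam_loader(
    raw_paths="/path/to/val/images", raw_key="*.tif",
    label_paths="/path/to/val/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True, is_train=False,
)

train_sam(
    name="finetuned_vit_b",  # checkpoints go to <save_root>/checkpoints/finetuned_vit_b/
    model_type="vit_b",
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=50,
    with_segmentation_decoder=True,  # also trains the UNETR instance segmentation decoder
)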
def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> torch.utils.data.Dataset:
Create a PyTorch Dataset for training a SAM model.
Arguments:
- raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
- raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
- label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- patch_shape: The shape for training patches.
- with_segmentation_decoder: Whether to train with an additional segmentation decoder.
- with_channels: Whether the image data has channels. By default, this is determined from the inputs.
- sampler: A sampler to reject batches according to a given criterion.
- raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
- n_samples: The number of samples for this dataset.
- is_train: Whether this dataset is used for training or validation.
- min_size: Minimal object size. Smaller objects will be filtered.
- max_sampling_attempts: Number of sampling attempts to make from a dataset.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
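Two hedged usage sketches follow: one for a folder of 2D images selected via a glob key, and one for volumetric data in an hdf5-like container. The paths and the internal keys ('raw', 'labels') are placeholders.

from micro_sam.training import default_sam_dataset

# Variant 1: a folder of 2D images, selected with a glob pattern as the key.
ds_images = default_sam_dataset(
    raw_paths="/path/to/images", raw_key="*.tif",
    label_paths="/path/to/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
)

# Variant 2: volumetric data in a container file, using internal dataset paths as keys.
ds_volume = default_sam_dataset(
    raw_paths="/path/to/data.h5", raw_key="raw",
    label_paths="/path/to/data.h5", label_key="labels",
    patch_shape=(1, 512, 512),  # (z, y, x); a singleton z-axis is also added automatically for 2D shapes
    with_segmentation_decoder=False,
    is_train=False,  # validation split: lowers the minimum number of samples per epoch
)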
def default_sam_loader(**kwargs) -> torch.utils.data.DataLoader:
Create a PyTorch DataLoader for training a SAM model.
Arguments:
- kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
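Because of the kwargs splitting described above, dataset arguments and DataLoader arguments can be passed together in one call; the sketch below (with placeholder paths) shows the mix.

from micro_sam.training import default_sam_loader

# Dataset arguments (routed to default_sam_dataset / torch_em.default_segmentation_dataset)
# and DataLoader arguments are mixed freely; split_kwargs sends each to the right consumer.
loader = default_sam_loader(
    # Dataset arguments:
    raw_paths="/path/to/images", raw_key="*.tif",
    label_paths="/path/to/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    # DataLoader arguments:
    batch_size=2, shuffle=True, num_workers=4,
)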
CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}

Best training configurations for given hardware resources.
def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
Run training for a SAM model with the configuration for a given hardware resource.
Selects the best training settings for the given configuration.
The available configurations are listed in `CONFIGURATIONS`.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- configuration: The configuration (= name of hardware resource).
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation.
- model_type: Override the default model type. This can be used to use one of the micro_sam models as a starting point instead of a default SAM model.
- kwargs: Additional keyword parameters that will be passed to `train_sam`.
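A short sketch of a call with one of the keys from `CONFIGURATIONS`; any extra keyword argument, such as `n_epochs`, is forwarded to `train_sam`. The loaders are assumed to have been built with `default_sam_loader` as in the examples above.

from micro_sam.training import train_sam_for_configuration

# train_loader / val_loader: built with default_sam_loader as shown earlier.
train_sam_for_configuration(
    name="sam_model",
    configuration="rtx5000",   # one of the CONFIGURATIONS keys: vit_b, 10 objects per batch
    train_loader=train_loader,
    val_loader=val_loader,
    with_segmentation_decoder=True,
    n_epochs=100,              # forwarded to train_sam
)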