micro_sam.training.training
Module source:

import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio

import torch
from torch.optim import Optimizer
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler

import torch_em
from torch_em.util import load_data
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from . import joint_sam_trainer as joint_trainers
from ..util import get_device, get_model_names, export_custom_sam_model
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform


FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is between [0, 255].
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channels for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}"
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}"
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        counter += 1
        if verify_n_labels_in_loader is not None and counter > verify_n_labels_in_loader:
            break


# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield


def _count_parameters(model_parameters):
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")


def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None, all objects will be used;
            if given, objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will override n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
        model_kwargs: Additional keyword arguments for `util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

        # Create the UNETR decoder (if we train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            joint_model_params = [params for params in model.parameters()]  # sam parameters
            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            model_params = joint_model_params
        else:
            model_params = model.parameters()

        optimizer = optimizer_class(model_params, lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)
        seconds = int(round(t_run % 60, 0))
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")


def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
    assert isinstance(path, (str, os.PathLike))

    # Check the underlying data dimensionality.
    if raw_key is None:  # If no key is given then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file from key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

    if not isinstance(patch_shape, tuple):
        patch_shape = tuple(patch_shape)

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        return (1,) + patch_shape
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + patch_shape
    else:
        return patch_shape


def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, this is decided based on the input data.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
    is_seg_dataset = kwargs.pop("is_seg_dataset", None)

    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
    # Get valid raw paths to make checks possible.
    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
    else:  # Otherwise 'raw_key' is either None or an internal path for a container format supported by 'elf'.
        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]

    # Load one of the raw inputs to validate whether it is RGB or not.
    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
    if test_raw_inputs.ndim == 3:
        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
            is_seg_dataset = False  # We use 'ImageCollectionDataset' in this case.
            # We need to provide a list of inputs to 'ImageCollectionDataset'.
            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths

            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
            with_channels = False if with_channels is None else with_channels

        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is an RGB image and has channels first.
            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
            with_channels = True if with_channels is None else with_channels

    # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'.
    # The user's choice takes priority; otherwise we use this suggested default.
    with_channels = False if with_channels is None else with_channels

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=True,
            min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            is_seg_dataset=is_seg_dataset,
            raw_transform=raw_transform,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset


def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
    # which the users can provide to get their desired segmentation dataset.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

    ds = default_sam_dataset(**ds_kwargs)
    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)


CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources."""


def _find_best_configuration():
    if torch.cuda.is_available():

        # Check how much memory we have and select the best matching GPU
        # for the available VRAM size.
        _, vram = torch.cuda.mem_get_info()
        vram = vram / 1e9  # in GB

        # Maybe we can get more configurations in the future.
        if vram > 80:  # More than 80 GB: use the A100 configurations.
            return "A100"
        elif vram > 30:  # More than 30 GB: use the V100 configurations.
            return "V100"
        elif vram > 14:  # More than 14 GB: use the RTX5000 configurations.
            return "rtx5000"
        else:  # Otherwise: not enough memory to train on the GPU, use CPU instead.
            return "CPU"
    else:
        return "CPU"


def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        model_type: Override the default model type.
            This can be used to start from one of the micro_sam models
            instead of a default SAM model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
    """
    if configuration in CONFIGURATIONS:
        train_kwargs = dict(CONFIGURATIONS[configuration])  # Copy, so that the global configuration is not mutated.
    else:
        raise ValueError(f"Invalid configuration '{configuration}'. Expect one of {list(CONFIGURATIONS.keys())}.")

    if model_type is None:
        model_type = train_kwargs.pop("model_type")
    else:
        expected_model_type = train_kwargs.pop("model_type")
        if model_type[:5] != expected_model_type:
            warnings.warn("You have specified a different model type.")

    train_kwargs.update(**kwargs)
    train_sam(
        name=name,
        train_loader=train_loader,
        val_loader=val_loader,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type,
        **train_kwargs
    )


def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):

    # Whether the model is stored in the current working directory or in another location.
    if save_root is None:
        save_root = os.getcwd()  # Map this to the current working directory, if not specified by the user.

    # Get the 'best' model checkpoint ready for export.
    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
    if not os.path.exists(best_checkpoint):
        raise FileNotFoundError(f"The trained model was not found at the expected location: '{best_checkpoint}'.")

    # Export the model if an output path has been given.
    if output_path:

        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
            export_custom_sam_model(
                checkpoint_path=best_checkpoint,
                model_type=model_type[:5],
                save_path=output_path,
                with_segmentation_decoder=with_segmentation_decoder,
            )

        # Otherwise we export it as a bioimage.io model.
        else:
            from micro_sam.bioimageio import export_sam_model

            # Load an image and corresponding labels from the val loader.
            with torch.no_grad():
                image_data, label_data = next(iter(val_loader))
                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()

                # Select the first channel of the label image if we have a channel axis, i.e. it contains the labels.
                if label_data.ndim == 3:
                    label_data = label_data[0]  # Gets the channel with instances.
                assert image_data.shape == label_data.shape
                label_data = label_data.astype("uint32")

                export_sam_model(
                    image=image_data,
                    label_image=label_data,
                    model_type=model_type[:5],
                    name=checkpoint_name,
                    output_path=output_path,
                    checkpoint_path=best_checkpoint,
                )

        # The final path where the model has been stored.
        final_path = output_path

    else:  # If no exports have been made, inform the user about the best checkpoint.
        final_path = best_checkpoint

    return final_path


def main():
    """@private"""
    import argparse

    available_models = list(get_model_names())
    available_configurations = list(CONFIGURATIONS.keys())

    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")

    # Images and labels for training.
    parser.add_argument(
        "--images", required=True, type=str, nargs="*",
        help="Filepath to images or the directory where the image data is stored."
    )
    parser.add_argument(
        "--labels", required=True, type=str, nargs="*",
        help="Filepath to ground-truth labels or the directory where the label data is stored."
    )
    parser.add_argument(
        "--image_key", type=str, default=None,
        help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--label_key", type=str, default=None,
        help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file."
    )

    # Images and labels for validation.
    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
    # Users can choose to have their explicit validation set via this feature as well.
    parser.add_argument(
        "--val_images", type=str, nargs="*",
        help="Filepath to images for validation or the directory where the image data is stored."
    )
    parser.add_argument(
        "--val_labels", type=str, nargs="*",
        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
    )
    parser.add_argument(
        "--val_image_key", type=str, default=None,
        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--val_label_key", type=str, default=None,
        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
    )

    # Other necessary stuff for training.
    parser.add_argument(
        "--configuration", type=str, default=_find_best_configuration(),
        help=f"The configuration for finetuning the Segment Anything Model, one of {', '.join(available_configurations)}."
    )
    parser.add_argument(
        "--segmentation_decoder", type=str, default="instances",  # TODO: in future, we can extend this to semantic seg.
        help="Whether to finetune the Segment Anything Model with an additional segmentation decoder for desired targets. "
             "By default, it trains with the additional segmentation decoder for instance segmentation."
    )

    # Optional advanced settings a user can opt to change the values for.
    parser.add_argument(
        "-d", "--device", type=str, default=None,
        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
             "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs="*", default=(512, 512),
        help="The choice of patch shape for training Segment Anything."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=None,
        help=f"The Segment Anything Model that will be used for finetuning, one of {', '.join(available_models)}."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded for finetuning."
    )
    parser.add_argument(
        "-s", "--save_root", type=str, default=None,
        help="The directory where the trained models and corresponding logs will be stored. "
             "By default, they are stored in your current working directory."
    )
    parser.add_argument(
        "--trained_model_name", type=str, default="sam_model",
        help="The custom name of the trained model. Allows users to have several trained models under the same 'save_root'."
    )
    parser.add_argument(
        "--output_path", type=str, default=None,
        help="The directory (e.g. '/path/to/folder') or filepath (e.g. '/path/to/model.pt') to export the trained model."
    )
    parser.add_argument(
        "--n_epochs", type=int, default=100,
        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
    )
    parser.add_argument(
        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The choice of batch size for training the Segment Anything Model. By default, trains on batch size 1."
    )
    parser.add_argument(
        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images. "
             "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
    )

    args = parser.parse_args()

    # 1. Get all necessary stuff for training.
    checkpoint_name = args.trained_model_name
    config = args.configuration
    model_type = args.model_type
    checkpoint_path = args.checkpoint_path
    batch_size = args.batch_size
    patch_shape = args.patch_shape
    epochs = args.n_epochs
    num_workers = args.num_workers
    device = args.device
    save_root = args.save_root
    output_path = args.output_path
    with_segmentation_decoder = (args.segmentation_decoder == "instances")

    # Get image paths and corresponding keys.
    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa

    # 2. Prepare the dataloaders.

    # If the user wants to preprocess the inputs, we allow the possibility to do so.
    _raw_transform = get_raw_transform(args.preprocess)

    # Get the dataset with files for training.
    dataset = default_sam_dataset(
        raw_paths=train_images,
        raw_key=train_image_key,
        label_paths=train_gt,
        label_key=train_gt_key,
        patch_shape=patch_shape,
        with_segmentation_decoder=with_segmentation_decoder,
        raw_transform=_raw_transform,
    )

    # If no validation images are provided explicitly, we create a val-split from the training data.
    if val_images is None:
        assert val_gt is None and val_image_key is None and val_gt_key is None
        # Use 10% of the dataset - at least one image - for validation.
        n_val = max(1, int(0.1 * len(dataset)))
        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])

    else:  # If validation images are provided, we create a new dataset for them.
        train_dataset = dataset
        val_dataset = default_sam_dataset(
            raw_paths=val_images,
            raw_key=val_image_key,
            label_paths=val_gt,
            label_key=val_gt_key,
            patch_shape=patch_shape,
            with_segmentation_decoder=with_segmentation_decoder,
            raw_transform=_raw_transform,
        )

    # Get the dataloaders from the datasets.
    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    # 3. Train the Segment Anything Model.

    # Get a valid model and other necessary parameters for training.
    if model_type is not None and model_type not in available_models:
        raise ValueError(f"'{model_type}' is not a valid choice of model.")
    if config is not None and config not in available_configurations:
        raise ValueError(f"'{config}' is not a valid choice of configuration.")

    if model_type is None:  # If the user does not specify the model, we use the default model corresponding to the config.
        model_type = CONFIGURATIONS[config]["model_type"]

    train_sam_for_configuration(
        name=checkpoint_name,
        configuration=config,
        model_type=model_type,
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=epochs,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        freeze=None,  # TODO: Allow for PEFT.
        device=device,
        save_root=save_root,
        peft_kwargs=None,  # TODO: Allow for PEFT.
    )

    # 4. Export the model, if desired by the user.
    final_path = _export_helper(
        save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader
    )

    print(f"Training has finished. The trained model is saved at {final_path}.")
FilePath = typing.Union[str, os.PathLike]
def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    **model_kwargs,
) -> None
Run training for a SAM model.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement.
- n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- device: The device to use for training.
- lr: The learning rate.
- n_sub_iteration: The number of iterative prompts per training iteration.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- mask_prob: The probability for using a mask as input in a given training sub-iteration.
- n_iterations: The number of iterations to use for training. This will override n_epochs if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- ignore_warnings: Whether to ignore raised warnings.
- verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
- box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
- model_kwargs: Additional keyword arguments for `util.get_sam_model`.
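For orientation, a minimal finetuning call might look like the following sketch. The folder layout, file patterns and model name are hypothetical placeholders; the loaders are built with `default_sam_loader`, which is documented below, and are assumed to satisfy the checks described above (raw data with 1 or 3 channels in [0, 255]).

import micro_sam.training as sam_training

# Hypothetical data layout: folders of tif images with matching instance label masks.
train_loader = sam_training.default_sam_loader(
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True,
)
val_loader = sam_training.default_sam_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True, is_train=False,
)

sam_training.train_sam(
    name="sam_for_my_data",          # placeholder checkpoint / log name
    model_type="vit_b",              # the base SAM model to finetune
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=50,
    with_segmentation_decoder=True,  # also train the UNETR decoder for instance segmentation
)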
def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset
Create a PyTorch Dataset for training a SAM model.
Arguments:
- raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
- raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
- label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- patch_shape: The shape for training patches.
- with_segmentation_decoder: Whether to train with additional segmentation decoder.
- with_channels: Whether the image data has channels. By default, this is decided based on the input data.
- sampler: A sampler to reject batches according to a given criterion.
- raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
- n_samples: The number of samples for this dataset.
- is_train: Whether this dataset is used for training or validation.
- min_size: Minimal object size. Smaller objects will be filtered.
- max_sampling_attempts: Number of sampling attempts to make from a dataset.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
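A sketch of constructing the dataset directly, under the same hypothetical folder layout as in the `train_sam` example (tif images with matching instance label masks):

from micro_sam.training import default_sam_dataset

dataset = default_sam_dataset(
    raw_paths="data/train/images",    # hypothetical folder with the training images
    raw_key="*.tif",                  # glob pattern selecting all tif files in it
    label_paths="data/train/labels",  # matching instance segmentation masks
    label_key="*.tif",
    patch_shape=(512, 512),
    with_segmentation_decoder=True,   # creates the 4-channel distance / foreground targets
)
print(len(dataset))  # for a training dataset, at least 100 samples per epoch (see n_samples above)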
def default_sam_loader(**kwargs) -> DataLoader
Create a PyTorch DataLoader for training a SAM model.
Arguments:
- kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
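Dataset arguments and DataLoader arguments can be mixed freely in a single call; the function separates them internally with `split_kwargs`. A sketch with the same hypothetical paths as above:

from micro_sam.training import default_sam_loader

loader = default_sam_loader(
    # Consumed by default_sam_dataset:
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    # Passed through to the PyTorch DataLoader:
    batch_size=2, shuffle=True, num_workers=4,
)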
CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}

Best training configurations for given hardware resources.
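Keys that a configuration does not set (for example n_sub_iteration in the rtx5000 entry) fall back to the `train_sam` defaults. A small lookup sketch:

from micro_sam.training.training import CONFIGURATIONS

settings = dict(CONFIGURATIONS["rtx5000"])  # copy, so the global configuration is not mutated
model_type = settings.pop("model_type")
print(model_type, settings)  # vit_b {'n_objects_per_batch': 10}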
def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None
Run training for a SAM model with the configuration for a given hardware resource.
Selects the best training settings for the given configuration.
The available configurations are listed in `CONFIGURATIONS`.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- configuration: The configuration (= name of hardware resource).
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
- model_type: Override the default model type. This can be used to start from one of the micro_sam models instead of a default SAM model.
- kwargs: Additional keyword parameters that will be passed to `train_sam`.
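A minimal sketch, reusing the placeholder dataloaders from the `train_sam` example above:

from micro_sam.training import train_sam_for_configuration

train_sam_for_configuration(
    name="sam_for_my_data",       # placeholder checkpoint / log name
    configuration="rtx5000",      # a hardware profile from CONFIGURATIONS
    train_loader=train_loader,    # assumed to exist, see the train_sam example
    val_loader=val_loader,
    with_segmentation_decoder=True,
    n_epochs=50,                  # forwarded to train_sam via **kwargs
)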