micro_sam.training.training
import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset

import torch_em
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from ..util import get_device
from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from . import joint_sam_trainer as joint_trainers
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit


FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    """Validate the inputs and targets produced by a data loader.

    The first batch is used to check the raw data (1 or 3 channels, values within [0, 255]).
    Afterwards the targets of up to `verify_n_labels_in_loader` batches are checked.

    Args:
        loader: The data loader to check. It has to yield pairs of (raw, target) batches.
        with_segmentation_decoder: Whether training uses the instance segmentation decoder.
            If True the targets must have 4 channels, otherwise 1 channel.
        name: Optional name of the loader, only used in the progress bar description.
        verify_n_labels_in_loader: The number of batches for which the targets are verified.
            If None, all batches of the loader are verified.

    Raises:
        ValueError: If the raw data or the targets do not fulfil the requirements.
    """
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is between [0, 255]
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from dataloader are valid (i.e. have atleast one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channel for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check instance channel per sample in a batch
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}"
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}"
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check instance channel per sample in a batch
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        # Stop after exactly `verify_n_labels_in_loader` batches have been verified.
        counter += 1
        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
            break


# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    """Adapter that forwards tqdm-like progress bar calls to the given signals object."""

    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


@contextmanager
def _filter_warnings(ignore_warnings):
    """Context manager that suppresses all warnings if `ignore_warnings` is True."""
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield


def _count_parameters(model_parameters):
    """Print the number of trainable parameters (in millions) among `model_parameters`."""
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")


def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs
            without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder
            for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely:
            image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

        # Create the UNETR decoder (if train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            joint_model_params = list(model.parameters())  # sam parameters
            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            optimizer = optimizer_class(joint_model_params, lr=lr)

        else:
            optimizer = optimizer_class(model.parameters(), lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        # Split the rounded runtime into hours, minutes within the hour and seconds within the minute.
        # Rounding first avoids displaying a seconds value of 60.
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")


def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    """Prepend a singleton axis to a 2D patch shape if the underlying raw data requires it.

    The dimensionality is determined from the first raw path: it is read as an image if
    `raw_key` is None, otherwise it is opened with `elf.io.open_file`, falling back to
    reading the first image matching the glob pattern `raw_key` inside the path.
    """
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
    assert isinstance(path, (str, os.PathLike))

    # Check the underlying data dimensionality.
    if raw_key is None:  # If no key is given then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file from key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        # The tuple conversion also supports patch shapes that are passed as a list.
        return (1,) + tuple(patch_shape)
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + tuple(patch_shape)
    else:
        return patch_shape


def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has RGB channels.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True, boundary_distances=True, directed_distances=False,
            foreground=True, instances=True, min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(
            min_size=min_size
        )

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape, raw_paths, raw_key, with_channels
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        # This loader is only used to determine the number of samples.
        # It gets the same channel setting as the actual dataset created below.
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset


def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
    # which the users can provide to get their desired segmentation dataset.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

    ds = default_sam_dataset(**ds_kwargs)
    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)


CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources.
"""


def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder
            for automatic instance segmentation.
        model_type: Over-ride the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default sam model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.

    Raises:
        ValueError: If `configuration` is not one of the keys of `CONFIGURATIONS`.
    """
    if configuration not in CONFIGURATIONS:
        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")

    # Work on a copy: popping from / updating the module-level entry would break
    # repeated calls with the same configuration and leak kwargs between calls.
    train_kwargs = dict(CONFIGURATIONS[configuration])

    if model_type is None:
        model_type = train_kwargs.pop("model_type")
    else:
        expected_model_type = train_kwargs.pop("model_type")
        if model_type[:5] != expected_model_type:
            warnings.warn("You have specified a different model type.")

    train_kwargs.update(**kwargs)
    train_sam(
        name=name, train_loader=train_loader, val_loader=val_loader,
        checkpoint_path=checkpoint_path, with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type, **train_kwargs
    )
FilePath =
typing.Union[str, os.PathLike]
def
train_sam( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, n_objects_per_batch: Optional[int] = 25, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, n_sub_iteration: int = 8, save_root: Union[os.PathLike, str, NoneType] = None, mask_prob: float = 0.5, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt5.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, **model_kwargs) -> None:
def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs
            without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder
            for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely:
            image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

        # Create the UNETR decoder (if train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            joint_model_params = list(model.parameters())  # sam parameters
            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            optimizer = optimizer_class(joint_model_params, lr=lr)

        else:
            optimizer = optimizer_class(model.parameters(), lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        # Split the rounded runtime into hours, minutes within the hour and seconds within the minute.
        # The previous code printed the total number of minutes (e.g. 01:90:00 for 1.5 hours).
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
Run training for a SAM model.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement.
- n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- device: The device to use for training.
- lr: The learning rate.
- n_sub_iteration: The number of iterative prompts per training iteration.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- mask_prob: The probability for using a mask as input in a given training sub-iteration.
- n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
- model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
- ignore_warnings: Whether to ignore raised warnings.
def
default_sam_dataset( raw_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: bool = False, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, is_train: bool = True, min_size: int = 25, max_sampling_attempts: Optional[int] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has RGB channels.
        sampler: A sampler to reject batches according to a given criterion.
            If not given, a `MinInstanceSampler` requiring 3 instances of at least `min_size` is used.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
            If not given, it is derived from the length of a default segmentation loader,
            with a minimum of 100 samples for training and 5 samples for validation.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True, boundary_distances=True, directed_distances=False,
            foreground=True, instances=True, min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(
            min_size=min_size
        )

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape, raw_paths, raw_key, with_channels
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        # This loader is only used to determine the number of samples.
        # It gets the same channel setting as the actual dataset created below.
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset
Create a PyTorch Dataset for training a SAM model.
Arguments:
- raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
- raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
- label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- patch_shape: The shape for training patches.
- with_segmentation_decoder: Whether to train with additional segmentation decoder.
- with_channels: Whether the image data has RGB channels.
- sampler: A sampler to reject batches according to a given criterion.
- raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
- n_samples: The number of samples for this dataset.
- is_train: Whether this dataset is used for training or validation.
- min_size: Minimal object size. Smaller objects will be filtered.
- max_sampling_attempts: Number of sampling attempts to make from a dataset.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def
default_sam_loader(**kwargs) -> torch.utils.data.dataloader.DataLoader:
def default_sam_loader(**kwargs) -> DataLoader:
    """Build a PyTorch DataLoader for training a SAM model.

    The keyword arguments are routed in two steps: everything accepted by
    `micro_sam.training.default_sam_dataset` or by `torch_em.default_segmentation_dataset`
    is used to create the dataset, and whatever is left is handed to the data loader.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Arguments that appear in the signature of `default_sam_dataset`.
    dataset_kwargs, remaining_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # Users may also pass parameters that only `torch_em.default_segmentation_dataset` knows about.
    # These belong to the dataset as well; anything beyond that is for the loader.
    torch_em_dataset_kwargs, dataloader_kwargs = split_kwargs(
        torch_em.default_segmentation_dataset, **remaining_kwargs
    )

    merged_dataset_kwargs = dict(dataset_kwargs)
    merged_dataset_kwargs.update(torch_em_dataset_kwargs)

    dataset = default_sam_dataset(**merged_dataset_kwargs)
    return torch_em.segmentation.get_data_loader(dataset, **dataloader_kwargs)
Create a PyTorch DataLoader for training a SAM model.
Arguments:
- kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
CONFIGURATIONS =
{'Minimal': {'model_type': 'vit_t', 'n_objects_per_batch': 4, 'n_sub_iteration': 4}, 'CPU': {'model_type': 'vit_b', 'n_objects_per_batch': 10}, 'gtx1080': {'model_type': 'vit_t', 'n_objects_per_batch': 5}, 'rtx5000': {'model_type': 'vit_b', 'n_objects_per_batch': 10}, 'V100': {'model_type': 'vit_b'}, 'A100': {'model_type': 'vit_h'}}
Best training configurations for given hardware resources.
def
train_sam_for_configuration( name: str, configuration: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, model_type: Optional[str] = None, **kwargs) -> None:
def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder
            for automatic instance segmentation.
        model_type: Override the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default sam model. A warning is issued if its first five
            characters differ from the model type of the configuration.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
            They take precedence over the settings of the configuration.

    Raises:
        ValueError: If `configuration` is not a key of `CONFIGURATIONS`.
    """
    if configuration not in CONFIGURATIONS:
        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")

    # Work on a copy. The settings are popped from and updated below; doing that on the
    # module-level dict would remove "model_type" for every later call with the same
    # configuration and would carry the overrides of one run over into the next.
    train_kwargs = dict(CONFIGURATIONS[configuration])

    expected_model_type = train_kwargs.pop("model_type")
    if model_type is None:
        model_type = expected_model_type
    elif model_type[:5] != expected_model_type:
        # Only the first five characters are compared, so a longer name that
        # starts with the expected model type does not trigger the warning.
        warnings.warn("You have specified a different model type.")

    # Explicit keyword arguments override the configuration defaults.
    train_kwargs.update(kwargs)
    train_sam(
        name=name, train_loader=train_loader, val_loader=val_loader,
        checkpoint_path=checkpoint_path, with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type, **train_kwargs
    )
Run training for a SAM model with the configuration for a given hardware resource.
Selects the best training settings for the given configuration.
The available configurations are listed in `CONFIGURATIONS`.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- configuration: The configuration (= name of hardware resource).
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
- model_type: Override the default model type. This can be used to start from one of the micro_sam models instead of a default SAM model.
- kwargs: Additional keyword parameters that will be passed to `train_sam`.