micro_sam.training.training
import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset

import torch_em
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from ..util import get_device
from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from . import joint_sam_trainer as joint_trainers
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit


FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is in the range [0, 255].
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channels for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}."
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}."
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        counter += 1
        # Stop once the requested number of batches has been verified.
        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
            break


# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


def _count_parameters(model_parameters):
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is {round(params, 2)}M")


@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield


def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs
            without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely:
            image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will override n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        model_kwargs: Additional keyword arguments for `util.get_sam_model`.
        ignore_warnings: Whether to ignore raised warnings.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )
        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

        # Create the UNETR decoder (if training with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            joint_model_params = list(model.parameters())  # the SAM parameters
            for param_name, params in unetr.named_parameters():  # the UNETR decoder parameters
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            optimizer = optimizer_class(joint_model_params, lr=lr)

        else:
            optimizer = optimizer_class(model.parameters(), lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)  # minutes within the current hour
        seconds = int(round(t_run % 60, 0))
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")


def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
    assert isinstance(path, (str, os.PathLike))

    # Check the underlying data dimensionality.
    if raw_key is None:  # If no key is given then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file from key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        return (1,) + patch_shape
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + patch_shape
    else:
        return patch_shape


def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with an additional segmentation decoder.
        with_channels: Whether the image data has RGB channels.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8 bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        is_seg_dataset: Whether to build the dataset as a torch_em.data.SegmentationDataset
            or as a torch_em.data.ImageCollectionDataset. By default this is derived from the input data.

    Returns:
        The dataset.
    """

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True, boundary_distances=True, directed_distances=False,
            foreground=True, instances=True, min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(
            min_size=min_size
        )

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape, raw_paths, raw_key, with_channels
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths, raw_key, label_paths, label_key, batch_size=1,
            patch_shape=patch_shape, ndim=2, is_seg_dataset=is_seg_dataset,
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform, label_transform=label_transform,
        with_channels=with_channels, ndim=2,
        sampler=sampler, n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset


def default_sam_loader(**kwargs) -> DataLoader:
    ds_kwargs, loader_kwargs = split_kwargs(default_sam_dataset, **kwargs)
    ds = default_sam_dataset(**ds_kwargs)
    loader = torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
    return loader


CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources.
"""


def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        model_type: Override the default model type.
            This can be used to use one of the micro_sam models as a starting point
            instead of a default SAM model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
    """
    if configuration in CONFIGURATIONS:
        train_kwargs = CONFIGURATIONS[configuration]
    else:
        raise ValueError(f"Invalid configuration {configuration}, expect one of {list(CONFIGURATIONS.keys())}")

    if model_type is None:
        model_type = train_kwargs.pop("model_type")
    else:
        expected_model_type = train_kwargs.pop("model_type")
        if model_type[:5] != expected_model_type:
            warnings.warn("You have specified a different model type.")

    train_kwargs.update(**kwargs)
    train_sam(
        name=name, train_loader=train_loader, val_loader=val_loader,
        checkpoint_path=checkpoint_path, with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type, **train_kwargs
    )
FilePath = typing.Union[str, os.PathLike]
def train_sam(
    name: str,
    model_type: str,
    train_loader: torch.utils.data.dataloader.DataLoader,
    val_loader: torch.utils.data.dataloader.DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Union[os.PathLike, str, NoneType] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Union[str, torch.device, NoneType] = None,
    lr: float = 1e-05,
    n_sub_iteration: int = 8,
    save_root: Union[os.PathLike, str, NoneType] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[PyQt5.QtCore.QObject] = None,
    optimizer_class: Optional[torch.optim.optimizer.Optimizer] = torch.optim.adamw.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    **model_kwargs,
) -> None:
Run training for a SAM model.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement.
- n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None all objects will be used, if given objects will be randomly sub-sampled.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- device: The device to use for training.
- lr: The learning rate.
- n_sub_iteration: The number of iterative prompts per training iteration.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- mask_prob: The probability for using a mask as input in a given training sub-iteration.
- n_iterations: The number of iterations to use for training. This will override n_epochs if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
- model_kwargs: Additional keyword arguments for util.get_sam_model.
- ignore_warnings: Whether to ignore raised warnings.
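
For orientation, here is a minimal fine-tuning sketch. The data paths are hypothetical, and the loaders are built with default_sam_loader, which is documented further below:

from micro_sam.training.training import train_sam, default_sam_loader

# Hypothetical folders of 2D tif images with matching instance label masks,
# selected via glob patterns.
train_loader = default_sam_loader(
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True,
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True, is_train=False,
)

# Fine-tune a ViT-B SAM together with the UNETR instance segmentation decoder.
train_sam(
    name="sam_finetuned",  # the checkpoint and logs will use this name
    model_type="vit_b",
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=50,
)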
def default_sam_dataset(
    raw_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike],
    raw_key: Optional[str],
    label_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    **kwargs,
) -> torch.utils.data.dataset.Dataset:
Create a PyTorch Dataset for training a SAM model.
Arguments:
- raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
- raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
- label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- patch_shape: The shape for training patches.
- with_segmentation_decoder: Whether to train with additional segmentation decoder.
- with_channels: Whether the image data has RGB channels.
- sampler: A sampler to reject batches according to a given criterion.
- raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
- n_samples: The number of samples for this dataset.
- is_train: Whether this dataset is used for training or validation.
- min_size: Minimal object size. Smaller objects will be filtered.
- max_sampling_attempts: Number of sampling attempts to make from a dataset.
- is_seg_dataset: Whether to build the dataset as a torch_em.data.SegmentationDataset or as a torch_em.data.ImageCollectionDataset. By default this is derived from the input data.
Returns:
The dataset.
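
A usage sketch for volumetric data in a hypothetical hdf5 container, where raw_key and label_key are internal dataset paths. For a 3D volume, a singleton z-dimension is added to the 2D patch shape automatically (see _update_patch_shape in the source above):

from micro_sam.training.training import default_sam_dataset

dataset = default_sam_dataset(
    raw_paths="data/volume.h5", raw_key="raw",       # hypothetical container
    label_paths="data/volume.h5", label_key="labels",
    patch_shape=(512, 512),
    with_segmentation_decoder=True,
)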
def default_sam_loader(**kwargs) -> torch.utils.data.dataloader.DataLoader:
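
Create a PyTorch DataLoader for training a SAM model. The function accepts the combined keyword arguments of default_sam_dataset and the torch DataLoader: split_kwargs routes each argument to the right place, builds the dataset, and wraps it with torch_em.segmentation.get_data_loader. A sketch with hypothetical paths:

from micro_sam.training.training import default_sam_loader

train_loader = default_sam_loader(
    # Dataset arguments, forwarded to `default_sam_dataset`:
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512),
    with_segmentation_decoder=True,
    # Loader arguments, forwarded to the torch DataLoader:
    batch_size=1, shuffle=True, num_workers=4,
)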
CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
Best training configurations for given hardware resources.
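
Each configuration is simply a dict of train_sam keyword overrides, for example:

from micro_sam.training.training import CONFIGURATIONS

print(CONFIGURATIONS["rtx5000"])
# {'model_type': 'vit_b', 'n_objects_per_batch': 10}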
def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: torch.utils.data.dataloader.DataLoader,
    val_loader: torch.utils.data.dataloader.DataLoader,
    checkpoint_path: Union[os.PathLike, str, NoneType] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
Run training for a SAM model with the configuration for a given hardware resource.
Selects the best training settings for the given configuration.
The available configurations are listed in CONFIGURATIONS.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- configuration: The configuration (= name of hardware resource).
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
- model_type: Override the default model type. This can be used to use one of the micro_sam models as a starting point instead of a default SAM model.
- kwargs: Additional keyword parameters that will be passed to train_sam.
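
A minimal usage sketch, assuming train_loader and val_loader were built as in the train_sam example above:

from micro_sam.training.training import train_sam_for_configuration

train_sam_for_configuration(
    name="sam_finetuned",
    configuration="rtx5000",  # preset: vit_b with 10 objects per batch
    train_loader=train_loader,
    val_loader=val_loader,
    with_segmentation_decoder=True,
)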