micro_sam.training.training
import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from collections import OrderedDict
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio

import torch
from torch.optim import Optimizer
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler

import torch_em
from torch_em.util import load_data
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from ..models.peft_sam import ClassicalSurgery
from . import joint_sam_trainer as joint_trainers
from ..util import get_device, get_model_names, export_custom_sam_model, get_sam_model
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform


FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is between [0, 255].
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channels for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}"
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}"
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        counter += 1
        if verify_n_labels_in_loader is not None and counter > verify_n_labels_in_loader:
            break


# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield


def _count_parameters(model_parameters):
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")


def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training):
    if n_iterations is None:
        trainer_fit_params = {"epochs": n_epochs}
    else:
        trainer_fit_params = {"iterations": n_iterations}

    if save_every_kth_epoch is not None:
        trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

    if pbar_signals is not None:
        progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
        trainer_fit_params["progress"] = progress_bar_wrapper

    # Avoid overwriting a trained model, if desired by the user.
    trainer_fit_params["overwrite_training"] = overwrite_training
    return trainer_fit_params


def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs):
    optimizer = optimizer_class(model_params, lr=lr)
    if scheduler_kwargs is None:
        scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3}
    scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)
    return optimizer, scheduler


def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    overwrite_training: bool = True,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
            By default, the value is set to '10' epochs.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
            By default, trains with the additional instance segmentation decoder.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training. By default, automatically chooses the best available device to train.
        lr: The learning rate. By default, set to '1e-5'.
        n_sub_iteration: The number of iterative prompts per training iteration.
            By default, the number of iterations is set to '8'.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
            By default, set to '0.5'.
        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
            By default, the distortion factor is set to '0.025'.
        overwrite_training: Whether to overwrite the trained model stored at the same location.
            By default, overwrites the trained model at each run.
            If set to 'False', it will avoid retraining the model if the previous run was completed.
        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

        # Create the UNETR decoder (if train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            joint_model_params = [params for params in model.parameters()]  # sam parameters
            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            model_params = joint_model_params
        else:
            model_params = model.parameters()

        optimizer, scheduler = _get_optimizer_and_scheduler(
            model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs
        )

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        trainer_fit_params = _get_trainer_fit_params(
            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
        )
        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)
        seconds = int(round(t_run % 60, 0))
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")


def export_instance_segmentation_model(
    trained_model_path: Union[str, os.PathLike],
    output_path: Union[str, os.PathLike],
    model_type: str,
    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Export a model trained for instance segmentation with `train_instance_segmentation`.

    The exported model will be compatible with the micro_sam functions, CLI and napari plugin.
    It should only be used for automatic segmentation and may not work well for interactive segmentation.

    Args:
        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
        output_path: The path where the exported model will be saved.
        model_type: The model type.
        initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
    """
    trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"]

    # Get the state of the encoder and instance segmentation decoder from the trained checkpoint.
    encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")])
    decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")])

    # Load the original state of the model that was used as the basis of instance segmentation training.
    _, model_state = get_sam_model(
        model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu",
    )
    # Remove the sam prefix if it's in the model state.
    prefix = "sam."
    model_state = OrderedDict(
        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
    )

    # Replace the image encoder state.
    model_state = OrderedDict(
        [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v)
         for k, v in model_state.items()]
    )

    save_state = {"model_state": model_state, "decoder_state": decoder_state}
    torch.save(save_state, output_path)


def train_instance_segmentation(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
    metric: Optional[torch.nn.Module] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    save_root: Optional[Union[str, os.PathLike]] = None,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    overwrite_training: bool = True,
    **model_kwargs,
) -> None:
    """Train a UNETR for instance segmentation using the SAM encoder as backbone.

    This setting corresponds to training a SAM model with an instance segmentation decoder,
    without training the model parts for interactive segmentation,
    i.e. without training the prompt encoder and mask decoder.

    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
    will not be compatible with the micro_sam functionality.
    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
    in a format that is compatible with micro_sam functionality.
    Note that the exported model should only be used for automatic segmentation via AIS.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
            By default, the value is set to '10' epochs.
        loss: The loss function to train the instance segmentation model.
            By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
        metric: The metric for the instance segmentation training.
            By default the loss function is used as the metric.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training. By default, automatically chooses the best available device to train.
        lr: The learning rate. By default, set to '1e-5'.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
        overwrite_training: Whether to overwrite the trained model stored at the same location.
            By default, overwrites the trained model at each run.
            If set to 'False', it will avoid retraining the model if the previous run was completed.
        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
    """

    with _filter_warnings(ignore_warnings):
        t_start = time.time()

        sam_model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            freeze=freeze,
            **model_kwargs
        )
        device = get_device(device)
        model = get_unetr(
            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
        )

        optimizer, scheduler = _get_optimizer_and_scheduler(
            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
        )
        trainer = torch_em.trainer.DefaultTrainer(
            name=name,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            mixed_precision=True,
            log_image_interval=50,
            compile_model=False,
            save_root=save_root,
            loss=loss,
            metric=loss if metric is None else metric,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            early_stopping=early_stopping,
        )

        trainer_fit_params = _get_trainer_fit_params(
            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
        )
        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)
        seconds = int(round(t_run % 60, 0))
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")


def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
    assert isinstance(path, (str, os.PathLike))

    # Check the underlying data dimensionality.
    if raw_key is None:  # If no key is given then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file from key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

    if not isinstance(patch_shape, tuple):
        patch_shape = tuple(patch_shape)

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        return (1,) + patch_shape
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + patch_shape
    else:
        return patch_shape


def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    train_instance_segmentation_only: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
        train_instance_segmentation_only: Set this argument to True in order to
            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        rois: The region of interest(s) for the data.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # Check if this dataset should be used for instance segmentation only training.
    # If yes, we set return_instances to False, since the instance channel must not
    # be passed for this training mode.
    return_instances = True
    if train_instance_segmentation_only:
        if not with_segmentation_decoder:
            raise ValueError(
                "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
            )
        return_instances = False

    # If a sampler is not passed, then we set a MinInstanceSampler, which requires at least 2 distinct instances per sample.
    # This is necessary, because training for interactive segmentation does not work on 'empty' images.
    # However, if we train only the automatic instance segmentation decoder, then this sampler is not required
    # and we do not set a default sampler.
    if sampler is None and not train_instance_segmentation_only:
        sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)

    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
    is_seg_dataset = kwargs.pop("is_seg_dataset", None)

    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
    # Get valid raw paths to make checks possible.
    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
    else:  # Otherwise, either 'raw_key' is None or a container format supported by 'elf', so we load one filepath.
        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]

    # Load one of the raw inputs to validate whether it is RGB or not.
    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
    if test_raw_inputs.ndim == 3:
        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
            # We need to provide a list of inputs to 'ImageCollectionDataset'.
            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths

            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
            with_channels = False if with_channels is None else with_channels

        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is an RGB image and has 3 channels first.
            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
            with_channels = True if with_channels is None else with_channels

    # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'.
    # Otherwise, let the user make the choice as priority, else set this to our suggested default.
    with_channels = False if with_channels is None else with_channels

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    # Prepare the label transform.
    if with_segmentation_decoder:
        default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=return_instances,
            min_size=min_size,
        )
    else:
        default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # Allow combining label transforms.
    custom_label_transform = kwargs.pop("label_transform", None)
    if custom_label_transform is None:
        label_transform = default_label_transform
    else:
        label_transform = torch_em.transform.generic.Compose(custom_label_transform, default_label_transform)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            is_seg_dataset=is_seg_dataset,
            raw_transform=raw_transform,
            rois=rois,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        rois=rois,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset


def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
    # which the users can provide to get their desired segmentation dataset.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

    ds = default_sam_dataset(**ds_kwargs)
    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)


CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "gtx3080": {
        "model_type": "vit_b", "n_objects_per_batch": 5,
        "peft_kwargs": {"attention_layers_to_update": [11], "peft_module": ClassicalSurgery}
    },
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources.
"""


def _find_best_configuration():
    if torch.cuda.is_available():

        # Check how much memory we have and select the best matching GPU
        # for the available VRAM size.
        _, vram = torch.cuda.mem_get_info()
        vram = vram / 1e9  # in GB

        # Maybe we can get more configurations in the future.
        if vram > 80:  # More than 80 GB: use the A100 configurations.
            return "A100"
        elif vram > 30:  # More than 30 GB: use the V100 configurations.
            return "V100"
        elif vram > 14:  # More than 14 GB: use the RTX5000 configurations.
            return "rtx5000"
        elif vram > 8:  # More than 8 GB: use the GTX3080 configurations.
            return "gtx3080"
        else:  # Otherwise: not enough memory to train on the GPU, use CPU instead.
            return "CPU"
    else:
        return "CPU"


def train_sam_for_configuration(
    name: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    configuration: Optional[str] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    train_instance_segmentation_only: bool = False,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        configuration: The configuration (= name of hardware resource).
            By default, it is automatically selected for the best VRAM combination.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
            By default, trains with the additional instance segmentation decoder.
        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
        model_type: Over-ride the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default sam model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
    """
    if configuration is None:  # Automatically choose based on available VRAM combination.
        configuration = _find_best_configuration()

    if configuration in CONFIGURATIONS:
        train_kwargs = CONFIGURATIONS[configuration]
    else:
        raise ValueError(f"Invalid configuration {configuration}, expect one of {list(CONFIGURATIONS.keys())}")

    if model_type is None:
        model_type = train_kwargs.pop("model_type")
    else:
        expected_model_type = train_kwargs.pop("model_type")
        if model_type[:5] != expected_model_type:
            warnings.warn("You have specified a different model type.")

    train_kwargs.update(**kwargs)
    if train_instance_segmentation_only:
        train_instance_segmentation(
            name=name,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_path=checkpoint_path,
            with_segmentation_decoder=with_segmentation_decoder,
            model_type=model_type,
            **train_kwargs
        )
    else:
        train_sam(
            name=name,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_path=checkpoint_path,
            with_segmentation_decoder=with_segmentation_decoder,
            model_type=model_type,
            **train_kwargs
        )


def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):

    # Whether the model is stored in the current working directory or in another location.
    if save_root is None:
        save_root = os.getcwd()  # Map this to current working directory, if not specified by the user.

    # Get the 'best' model checkpoint ready for export.
    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
    if not os.path.exists(best_checkpoint):
        raise FileNotFoundError(f"The trained model was not found at the expected location: '{best_checkpoint}'.")

    # Export the model if an output path has been given.
    if output_path:

        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
            export_custom_sam_model(
                checkpoint_path=best_checkpoint,
                model_type=model_type[:5],
                save_path=output_path,
                with_segmentation_decoder=with_segmentation_decoder,
            )

        # Otherwise we export it as bioimage.io model.
        else:
            from micro_sam.bioimageio import export_sam_model

            # Load image and corresponding labels from the val loader.
            with torch.no_grad():
                image_data, label_data = next(iter(val_loader))
                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()

                # Select the first channel of the label image if we have a channel axis, i.e. contains the labels.
                if label_data.ndim == 3:
                    label_data = label_data[0]  # Gets the channel with instances.
                assert image_data.shape == label_data.shape
                label_data = label_data.astype("uint32")

                export_sam_model(
                    image=image_data,
                    label_image=label_data,
                    model_type=model_type[:5],
                    name=checkpoint_name,
                    output_path=output_path,
                    checkpoint_path=best_checkpoint,
                )

        # The final path where the model has been stored.
        final_path = output_path

    else:  # If no exports have been made, inform the user about the best checkpoint.
        final_path = best_checkpoint

    return final_path


def _parse_segmentation_decoder(segmentation_decoder):
    if segmentation_decoder in ("None", "none"):
        with_segmentation_decoder, train_instance_segmentation_only = False, False
    elif segmentation_decoder == "instances":
        with_segmentation_decoder, train_instance_segmentation_only = True, False
    elif segmentation_decoder == "instances_only":
        with_segmentation_decoder, train_instance_segmentation_only = True, True
    else:
        raise ValueError(
            "The 'segmentation_decoder' argument currently supports the values:\n"
            f"'instances', 'instances_only', or 'None'. You have passed {segmentation_decoder}."
        )
    return with_segmentation_decoder, train_instance_segmentation_only


def main():
    """@private"""
    import argparse

    available_models = list(get_model_names())
    available_models = ", ".join(available_models)

    available_configurations = list(CONFIGURATIONS.keys())
    available_configurations = ", ".join(available_configurations)

    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")

    # Images and labels for training.
    parser.add_argument(
        "--images", required=True, type=str, nargs="*",
        help="Filepath to images or the directory where the image data is stored."
    )
    parser.add_argument(
        "--labels", required=True, type=str, nargs="*",
        help="Filepath to ground-truth labels or the directory where the label data is stored."
    )
    parser.add_argument(
        "--image_key", type=str, default=None,
        help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--label_key", type=str, default=None,
        help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file."
    )

    # Images and labels for validation.
    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
    # Users can choose to have their explicit validation set via this feature as well.
    parser.add_argument(
        "--val_images", type=str, nargs="*",
        help="Filepath to images for validation or the directory where the image data is stored."
    )
    parser.add_argument(
        "--val_labels", type=str, nargs="*",
        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
    )
    parser.add_argument(
        "--val_image_key", type=str, default=None,
        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--val_label_key", type=str, default=None,
        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
    )

    # Other necessary stuff for training.
    parser.add_argument(
        "--configuration", type=str, default=_find_best_configuration(),
        help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations}."
    )

    def none_or_str(value):
        if value.lower() == 'none':
            return None
        return value

    # This could be extended to train for semantic segmentation or other options.
    parser.add_argument(
        "--segmentation_decoder", type=none_or_str, default="instances",
        help="Whether to finetune Segment Anything Model with an additional segmentation decoder. "
             "The following options are possible:\n"
             "- 'instances' to train with an additional decoder for automatic instance segmentation. "
             "  This option enables using the automatic instance segmentation (AIS) mode.\n"
             "- 'instances_only' to train only the instance segmentation decoder. "
             "  In this case the parts of SAM that are used for interactive segmentation will not be trained.\n"
             "- 'None' to train without an additional segmentation decoder."
             "  This option trains only the parts of the original SAM.\n"
             "By default the option 'instances' is used."
    )

    # Optional advanced settings a user can opt to change the values for.
    parser.add_argument(
        "-d", "--device", type=str, default=None,
        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
             "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs="*", default=(512, 512),
        help="The choice of patch shape for training Segment Anything Model. "
             "By default, a patch size of 512x512 is used."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=None,
        help=f"The Segment Anything Model that will be used for finetuning, one of {available_models}."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded for finetuning."
    )
    parser.add_argument(
        "-s", "--save_root", type=str, default=None,
        help="The directory where the trained models and corresponding logs will be stored. "
             "By default, they are stored in your current working directory."
    )
    parser.add_argument(
        "--trained_model_name", type=str, default="sam_model",
        help="The custom name of the trained model sub-folder. Allows users to have several trained models "
             "under the same 'save_root'."
    )
    parser.add_argument(
        "--output_path", type=str, default=None,
        help="The directory (e.g. '/path/to/folder') or filepath (e.g. '/path/to/model.pt') to export the trained model."
    )
    parser.add_argument(
        "--n_epochs", type=int, default=100,
        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
    )
    parser.add_argument(
        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1."
    )
    parser.add_argument(
        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images. "
             "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
    )

    args = parser.parse_args()

    # 1. Get all necessary stuff for training.
    checkpoint_name = args.trained_model_name
    config = args.configuration
    model_type = args.model_type
    checkpoint_path = args.checkpoint_path
    batch_size = args.batch_size
    patch_shape = args.patch_shape
    epochs = args.n_epochs
    num_workers = args.num_workers
    device = args.device
    save_root = args.save_root
    output_path = args.output_path
    with_segmentation_decoder, train_instance_segmentation_only = _parse_segmentation_decoder(args.segmentation_decoder)

    # Get image paths and corresponding keys.
    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa

    # 2. Prepare the dataloaders.

    # If the user wants to preprocess the inputs, we allow the possibility to do so.
    _raw_transform = get_raw_transform(args.preprocess)

    # Get the dataset with files for training.
    dataset = default_sam_dataset(
        raw_paths=train_images,
        raw_key=train_image_key,
        label_paths=train_gt,
        label_key=train_gt_key,
        patch_shape=patch_shape,
        with_segmentation_decoder=with_segmentation_decoder,
        raw_transform=_raw_transform,
        train_instance_segmentation_only=train_instance_segmentation_only,
    )

    # If val images are not explicitly provided, we create a val split from the training data.
    if val_images is None:
        assert val_gt is None and val_image_key is None and val_gt_key is None
        # Use 10% of the dataset - at least one image - for validation.
        n_val = max(1, int(0.1 * len(dataset)))
        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])

    else:  # If val images are provided, we create a new dataset for them.
        train_dataset = dataset
        val_dataset = default_sam_dataset(
            raw_paths=val_images,
            raw_key=val_image_key,
            label_paths=val_gt,
            label_key=val_gt_key,
            patch_shape=patch_shape,
            with_segmentation_decoder=with_segmentation_decoder,
            train_instance_segmentation_only=train_instance_segmentation_only,
            raw_transform=_raw_transform,
        )

    # Get the dataloaders from the datasets.
    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    # 3. Train the Segment Anything Model.

    # Get a valid model and other necessary parameters for training.
    if model_type is not None and model_type not in available_models:
        raise ValueError(f"'{model_type}' is not a valid choice of model.")
    if config is not None and config not in available_configurations:
        raise ValueError(f"'{config}' is not a valid choice of configuration.")

    if model_type is None:  # If user does not specify the model, we use the default model corresponding to the config.
        model_type = CONFIGURATIONS[config]["model_type"]

    train_sam_for_configuration(
        name=checkpoint_name,
        configuration=config,
        model_type=model_type,
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=epochs,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        freeze=None,  # TODO: Allow for PEFT.
        device=device,
        save_root=save_root,
        peft_kwargs=None,  # TODO: Allow for PEFT.
        train_instance_segmentation_only=train_instance_segmentation_only,
    )

    # 4. Export the model, if desired by the user.
    if train_instance_segmentation_only and output_path:
        trained_path = os.path.join("" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt")
        export_instance_segmentation_model(trained_path, output_path, model_type, checkpoint_path)
        final_path = output_path
    else:
        final_path = _export_helper(
            save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader,
        )

    print(f"Training has finished. The trained model is saved at {final_path}.")
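For orientation, the functions defined above are typically combined as in the following sketch. This is not part of the module itself: the file paths, glob keys and model name are placeholders, and the loaders assume 2D images with instance labels that torch_em can read.

# Minimal sketch: build dataloaders and finetune with hardware-appropriate settings.
from micro_sam.training.training import default_sam_loader, train_sam_for_configuration

train_loader = default_sam_loader(
    raw_paths="data/train/images",    # placeholder folder with training images
    raw_key="*.tif",                  # glob pattern selecting the image files
    label_paths="data/train/labels",  # placeholder folder with instance label images
    label_key="*.tif",
    patch_shape=(512, 512),
    with_segmentation_decoder=True,
    batch_size=1, shuffle=True,       # forwarded to the PyTorch DataLoader
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    is_train=False, batch_size=1, shuffle=True,
)

# Picks the entry from CONFIGURATIONS that matches the available hardware
# (configuration=None triggers _find_best_configuration).
train_sam_for_configuration(
    name="sam_finetuned",             # placeholder checkpoint name
    train_loader=train_loader,
    val_loader=val_loader,
    with_segmentation_decoder=True,
)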
def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    overwrite_training: bool = True,
    **model_kwargs,
) -> None:
Run training for a SAM model.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
- n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- device: The device to use for training. By default, automatically chooses the best available device to train.
- lr: The learning rate. By default, set to '1e-5'.
- n_sub_iteration: The number of iterative prompts per training iteration. By default, the number of iterations is set to '8'.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- mask_prob: The probability for using a mask as input in a given training sub-iteration. By default, set to '0.5'.
- n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
- verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
- box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. By default, the distortion factor is set to '0.025'.
- overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
- model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
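A minimal usage sketch for `train_sam` (not part of the module itself; the data layout, names, and settings below are placeholder assumptions, with the loaders built via `default_sam_loader`, documented further below):

from micro_sam.training.training import default_sam_loader, train_sam

# Hypothetical data layout: folders of 2D tif images with matching instance label images.
loader_kwargs = dict(
    raw_key="*.tif", label_key="*.tif", patch_shape=(512, 512),
    with_segmentation_decoder=True, batch_size=1, shuffle=True,
)
train_loader = default_sam_loader(
    raw_paths="data/train/images", label_paths="data/train/labels", is_train=True, **loader_kwargs
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", label_paths="data/val/labels", is_train=False, **loader_kwargs
)

train_sam(
    name="sam_finetuned",            # placeholder name; checkpoints and logs go to 'checkpoints/sam_finetuned'
    model_type="vit_b",              # the SAM model type to start from
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=100,
    with_segmentation_decoder=True,  # also train the decoder for automatic instance segmentation
)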
373def export_instance_segmentation_model( 374 trained_model_path: Union[str, os.PathLike], 375 output_path: Union[str, os.PathLike], 376 model_type: str, 377 initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None, 378) -> None: 379 """Export a model trained for instance segmentation with `train_instance_segmentation`. 380 381 The exported model will be compatible with the micro_sam functions, CLI and napari plugin. 382 It should only be used for automatic segmentation and may not work well for interactive segmentation. 383 384 Args: 385 trained_model_path: The path to the checkpoint of the model trained for instance segmentation. 386 output_path: The path where the exported model will be saved. 387 model_type: The model type. 388 initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional). 389 """ 390 trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"] 391 392 # Get the state of the encoder and instance segmentation decoder from the trained checkpoint. 393 encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")]) 394 decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")]) 395 396 # Load the original state of the model that was used as the basis of instance segmentation training. 397 _, model_state = get_sam_model( 398 model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu", 399 ) 400 # Remove the sam prefix if it's in the model state. 401 prefix = "sam." 402 model_state = OrderedDict( 403 [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] 404 ) 405 406 # Replace the image encoder state. 407 model_state = OrderedDict( 408 [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v) 409 for k, v in model_state.items()] 410 ) 411 412 save_state = {"model_state": model_state, "decoder_state": decoder_state} 413 torch.save(save_state, output_path)
Export a model trained for instance segmentation with `train_instance_segmentation`.
The exported model will be compatible with the micro_sam functions, CLI and napari plugin. It should only be used for automatic segmentation and may not work well for interactive segmentation.
Arguments:
- trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
- output_path: The path where the exported model will be saved.
- model_type: The model type.
- initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
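A sketch of the export step (the paths below are placeholders; the file name 'best.pt' is an assumption based on the torch_em checkpoint convention, not stated by this module):

from micro_sam.training.training import export_instance_segmentation_model

export_instance_segmentation_model(
    # Assumed checkpoint location for a model trained under the name 'ais_model';
    # torch_em's trainer typically stores the best checkpoint as 'best.pt'.
    trained_model_path="checkpoints/ais_model/best.pt",
    output_path="ais_model_exported.pt",
    model_type="vit_b",               # the model type the training was based on
    initial_checkpoint_path=None,     # set this if the training started from a custom checkpoint
)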
416def train_instance_segmentation( 417 name: str, 418 model_type: str, 419 train_loader: DataLoader, 420 val_loader: DataLoader, 421 n_epochs: int = 100, 422 early_stopping: Optional[int] = 10, 423 loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True), 424 metric: Optional[torch.nn.Module] = None, 425 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 426 freeze: Optional[List[str]] = None, 427 device: Optional[Union[str, torch.device]] = None, 428 lr: float = 1e-5, 429 save_root: Optional[Union[str, os.PathLike]] = None, 430 n_iterations: Optional[int] = None, 431 scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau, 432 scheduler_kwargs: Optional[Dict[str, Any]] = None, 433 save_every_kth_epoch: Optional[int] = None, 434 pbar_signals: Optional[QObject] = None, 435 optimizer_class: Optional[Optimizer] = torch.optim.AdamW, 436 peft_kwargs: Optional[Dict] = None, 437 ignore_warnings: bool = True, 438 overwrite_training: bool = True, 439 **model_kwargs, 440) -> None: 441 """Train a UNETR for instance segmentation using the SAM encoder as backbone. 442 443 This setting corresponds to training a SAM model with an instance segmentation decoder, 444 without training the model parts for interactive segmentation, 445 i.e. without training the prompt encoder and mask decoder. 446 447 The checkpoint of the trained model, which will be saved in 'checkpoints/<name>', 448 will not be compatible with the micro_sam functionality. 449 You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it 450 in a format that is compatible with micro_sam functionality. 451 Note that the exported model should only be used for automatic segmentation via AIS. 452 453 Args: 454 name: The name of the model to be trained. The checkpoint and logs will have this name. 455 model_type: The type of the SAM model. 456 train_loader: The dataloader for training. 457 val_loader: The dataloader for validation. 458 n_epochs: The number of epochs to train for. 459 early_stopping: Enable early stopping after this number of epochs without improvement. 460 By default, the value is set to '10' epochs. 461 loss: The loss function to train the instance segmentation model. 462 By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss' 463 metric: The metric for the instance segmentation training. 464 By default the loss function is used as the metric. 465 checkpoint_path: Path to checkpoint for initializing the SAM model. 466 freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. 467 By default nothing is frozen and the full model is updated. 468 device: The device to use for training. By default, automatically chooses the best available device to train. 469 lr: The learning rate. By default, set to '1e-5'. 470 save_root: Optional root directory for saving the checkpoints and logs. 471 If not given the current working directory is used. 472 n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given. 473 scheduler_class: The learning rate scheduler to update the learning rate. 474 By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used. 475 scheduler_kwargs: The learning rate scheduler parameters. 476 If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`. 477 save_every_kth_epoch: Save checkpoints after every kth epoch separately. 478 pbar_signals: Controls for napari progress bar. 
479 optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used. 480 peft_kwargs: Keyword arguments for the PEFT wrapper class. 481 ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'. 482 overwrite_training: Whether to overwrite the trained model stored at the same location. 483 By default, overwrites the trained model at each run. 484 If set to 'False', it will avoid retraining the model if the previous run was completed. 485 model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`. 486 """ 487 488 with _filter_warnings(ignore_warnings): 489 t_start = time.time() 490 491 sam_model, state = get_trainable_sam_model( 492 model_type=model_type, 493 device=device, 494 checkpoint_path=checkpoint_path, 495 return_state=True, 496 peft_kwargs=peft_kwargs, 497 freeze=freeze, 498 **model_kwargs 499 ) 500 device = get_device(device) 501 model = get_unetr( 502 image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device, 503 ) 504 505 optimizer, scheduler = _get_optimizer_and_scheduler( 506 model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs 507 ) 508 trainer = torch_em.trainer.DefaultTrainer( 509 name=name, 510 model=model, 511 train_loader=train_loader, 512 val_loader=val_loader, 513 device=device, 514 mixed_precision=True, 515 log_image_interval=50, 516 compile_model=False, 517 save_root=save_root, 518 loss=loss, 519 metric=loss if metric is None else metric, 520 optimizer=optimizer, 521 lr_scheduler=scheduler, 522 early_stopping=early_stopping, 523 ) 524 525 trainer_fit_params = _get_trainer_fit_params( 526 n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training 527 ) 528 trainer.fit(**trainer_fit_params) 529 530 t_run = time.time() - t_start 531 hours = int(t_run // 3600) 532 minutes = int(t_run // 60) 533 seconds = int(round(t_run % 60, 0)) 534 print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
Train a UNETR for instance segmentation using the SAM encoder as backbone.
This setting corresponds to training a SAM model with an instance segmentation decoder, without training the model parts for interactive segmentation, i.e. without training the prompt encoder and mask decoder.
The checkpoint of the trained model, which will be saved in 'checkpoints/&lt;name&gt;', will not be compatible with the micro_sam functionality. You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it in a format that is compatible with micro_sam functionality.
Note that the exported model should only be used for automatic segmentation via AIS.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
- loss: The loss function to train the instance segmentation model. By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
- metric: The metric for the instance segmentation training. By default the loss function is used as the metric.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. By default nothing is frozen and the full model is updated.
- device: The device to use for training. By default, automatically chooses the best available device to train.
- lr: The learning rate. By default, set to '1e-5'.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
- overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
- model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
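A usage sketch (paths and names are placeholders; the loaders for this training mode are assumed to be created with `train_instance_segmentation_only=True`, see `default_sam_dataset` below):

from micro_sam.training.training import default_sam_loader, train_instance_segmentation

# Hypothetical data layout: folders of 2D tif images with matching instance label images.
loader_kwargs = dict(
    raw_key="*.tif", label_key="*.tif", patch_shape=(512, 512),
    with_segmentation_decoder=True, train_instance_segmentation_only=True,
    batch_size=2, shuffle=True,
)
train_loader = default_sam_loader(
    raw_paths="data/train/images", label_paths="data/train/labels", is_train=True, **loader_kwargs
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", label_paths="data/val/labels", is_train=False, **loader_kwargs
)

train_instance_segmentation(
    name="ais_model",                 # placeholder name; the checkpoint is saved under 'checkpoints/ais_model'
    model_type="vit_b",
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=100,
)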
570def default_sam_dataset( 571 raw_paths: Union[List[FilePath], FilePath], 572 raw_key: Optional[str], 573 label_paths: Union[List[FilePath], FilePath], 574 label_key: Optional[str], 575 patch_shape: Tuple[int], 576 with_segmentation_decoder: bool, 577 with_channels: Optional[bool] = None, 578 train_instance_segmentation_only: bool = False, 579 sampler: Optional[Callable] = None, 580 raw_transform: Optional[Callable] = None, 581 n_samples: Optional[int] = None, 582 is_train: bool = True, 583 min_size: int = 25, 584 max_sampling_attempts: Optional[int] = None, 585 rois: Optional[Union[slice, Tuple[slice, ...]]] = None, 586 **kwargs, 587) -> Dataset: 588 """Create a PyTorch Dataset for training a SAM model. 589 590 Args: 591 raw_paths: The path(s) to the image data used for training. 592 Can either be multiple 2D images or volumetric data. 593 raw_key: The key for accessing the image data. Internal filepath for hdf5-like input 594 or a glob pattern for selecting multiple files. 595 label_paths: The path(s) to the label data used for training. 596 Can either be multiple 2D images or volumetric data. 597 label_key: The key for accessing the label data. Internal filepath for hdf5-like input 598 or a glob pattern for selecting multiple files. 599 patch_shape: The shape for training patches. 600 with_segmentation_decoder: Whether to train with additional segmentation decoder. 601 with_channels: Whether the image data has channels. By default, it makes the decision based on inputs. 602 train_instance_segmentation_only: Set this argument to True in order to 603 pass the dataset to `train_instance_segmentation`. By default, set to 'False'. 604 sampler: A sampler to reject batches according to a given criterion. 605 raw_transform: Transformation applied to the image data. 606 If not given the data will be cast to 8bit. 607 n_samples: The number of samples for this dataset. 608 is_train: Whether this dataset is used for training or validation. By default, set to 'True'. 609 min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'. 610 max_sampling_attempts: Number of sampling attempts to make from a dataset. 611 rois: The region of interest(s) for the data. 612 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 613 614 Returns: 615 The segmentation dataset. 616 """ 617 618 # Check if this dataset should be used for instance segmentation only training. 619 # If yes, we set return_instances to False, since the instance channel must not 620 # be passed for this training mode. 621 return_instances = True 622 if train_instance_segmentation_only: 623 if not with_segmentation_decoder: 624 raise ValueError( 625 "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True." 626 ) 627 return_instances = False 628 629 # If a sampler is not passed, then we set a MinInstanceSampler, which requires 3 distinct instances per sample. 630 # This is necessary, because training for interactive segmentation does not work on 'empty' images. 631 # However, if we train only the automatic instance segmentation decoder, then this sampler is not required 632 # and we do not set a default sampler. 633 if sampler is None and not train_instance_segmentation_only: 634 sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size) 635 636 # By default, let the 'default_segmentation_dataset' heuristic decide for itself. 637 is_seg_dataset = kwargs.pop("is_seg_dataset", None) 638 639 # Check if the raw inputs are RGB or not. 
If yes, use 'ImageCollectionDataset'. 640 # Get valid raw paths to make checks possible. 641 if raw_key and "*" in raw_key: # Use the wildcard pattern to find the filepath to only one image. 642 rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0] 643 else: # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath. 644 rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0] 645 646 # Load one of the raw inputs to validate whether it is RGB or not. 647 test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None) 648 if test_raw_inputs.ndim == 3: 649 if test_raw_inputs.shape[-1] == 3: # i.e. if it is an RGB image and has channels last. 650 is_seg_dataset = False # we use 'ImageCollectionDataset' in this case. 651 # We need to provide a list of inputs to 'ImageCollectionDataset'. 652 raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths 653 label_paths = [label_paths] if isinstance(label_paths, str) else label_paths 654 655 # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'. 656 with_channels = False if with_channels is None else with_channels 657 658 elif test_raw_inputs.shape[0] == 3: # i.e. if it is a RGB image and has 3 channels first. 659 # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'. 660 with_channels = True if with_channels is None else with_channels 661 662 # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset' 663 # Otherwise, let the user make the choice as priority, else set this to our suggested default. 664 with_channels = False if with_channels is None else with_channels 665 666 # Set the data transformations. 667 if raw_transform is None: 668 raw_transform = require_8bit 669 670 # Prepare the label transform. 671 if with_segmentation_decoder: 672 default_label_transform = torch_em.transform.label.PerObjectDistanceTransform( 673 distances=True, 674 boundary_distances=True, 675 directed_distances=False, 676 foreground=True, 677 instances=return_instances, 678 min_size=min_size, 679 ) 680 else: 681 default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size) 682 683 # Allow combining label transforms. 684 custom_label_transform = kwargs.pop("label_transform", None) 685 if custom_label_transform is None: 686 label_transform = default_label_transform 687 else: 688 label_transform = torch_em.transform.generic.Compose(custom_label_transform, default_label_transform) 689 690 # Check the patch shape to add a singleton if required. 691 patch_shape = _update_patch_shape( 692 patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels, 693 ) 694 695 # Set a minimum number of samples per epoch. 
696 if n_samples is None: 697 loader = torch_em.default_segmentation_loader( 698 raw_paths=raw_paths, 699 raw_key=raw_key, 700 label_paths=label_paths, 701 label_key=label_key, 702 batch_size=1, 703 patch_shape=patch_shape, 704 with_channels=with_channels, 705 ndim=2, 706 is_seg_dataset=is_seg_dataset, 707 raw_transform=raw_transform, 708 rois=rois, 709 **kwargs 710 ) 711 n_samples = max(len(loader), 100 if is_train else 5) 712 713 dataset = torch_em.default_segmentation_dataset( 714 raw_paths=raw_paths, 715 raw_key=raw_key, 716 label_paths=label_paths, 717 label_key=label_key, 718 patch_shape=patch_shape, 719 raw_transform=raw_transform, 720 label_transform=label_transform, 721 with_channels=with_channels, 722 ndim=2, 723 sampler=sampler, 724 n_samples=n_samples, 725 is_seg_dataset=is_seg_dataset, 726 rois=rois, 727 **kwargs, 728 ) 729 730 if max_sampling_attempts is not None: 731 if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset): 732 for ds in dataset.datasets: 733 ds.max_sampling_attempts = max_sampling_attempts 734 else: 735 dataset.max_sampling_attempts = max_sampling_attempts 736 737 return dataset
Create a PyTorch Dataset for training a SAM model.
Arguments:
- raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
- raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
- label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
- patch_shape: The shape for training patches.
- with_segmentation_decoder: Whether to train with additional segmentation decoder.
- with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
- train_instance_segmentation_only: Set this argument to True in order to pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
- sampler: A sampler to reject batches according to a given criterion.
- raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
- n_samples: The number of samples for this dataset.
- is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
- min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
- max_sampling_attempts: Number of sampling attempts to make from a dataset.
- rois: The region of interest(s) for the data.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
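A sketch for building the dataset directly, here from explicit lists of 2D image files (the paths are placeholders; passing `None` as key assumes each file directly holds the image or label data):

from glob import glob
from micro_sam.training.training import default_sam_dataset

# Hypothetical file lists; each tif file holds one 2D image / label image.
image_paths = sorted(glob("data/train/images/*.tif"))
label_paths = sorted(glob("data/train/labels/*.tif"))

train_dataset = default_sam_dataset(
    raw_paths=image_paths, raw_key=None,
    label_paths=label_paths, label_key=None,
    patch_shape=(512, 512),
    with_segmentation_decoder=True,
    min_size=25,          # filter out objects smaller than 25 pixels
    n_samples=1000,       # number of patches sampled per epoch
    is_train=True,
)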
740def default_sam_loader(**kwargs) -> DataLoader: 741 """Create a PyTorch DataLoader for training a SAM model. 742 743 Args: 744 kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader. 745 746 Returns: 747 The DataLoader. 748 """ 749 sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs) 750 751 # There might be additional parameters supported by `torch_em.default_segmentation_dataset`, 752 # which the users can provide to get their desired segmentation dataset. 753 extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs) 754 ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs} 755 756 ds = default_sam_dataset(**ds_kwargs) 757 return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
Create a PyTorch DataLoader for training a SAM model.
Arguments:
- kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
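A sketch showing that dataset arguments and PyTorch DataLoader arguments are passed in one call and split internally (the paths are placeholders):

from micro_sam.training.training import default_sam_loader

train_loader = default_sam_loader(
    # arguments consumed by default_sam_dataset / torch_em.default_segmentation_dataset
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    # arguments forwarded to the PyTorch DataLoader
    batch_size=2, shuffle=True, num_workers=4,
)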
Best training configurations for given hardware resources.
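To inspect the available presets and the training settings they map to (a small sketch; the exact keys depend on the installed version):

from micro_sam.training.training import CONFIGURATIONS

for configuration, settings in CONFIGURATIONS.items():
    print(configuration, "->", settings)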
799def train_sam_for_configuration( 800 name: str, 801 train_loader: DataLoader, 802 val_loader: DataLoader, 803 configuration: Optional[str] = None, 804 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 805 with_segmentation_decoder: bool = True, 806 train_instance_segmentation_only: bool = False, 807 model_type: Optional[str] = None, 808 **kwargs, 809) -> None: 810 """Run training for a SAM model with the configuration for a given hardware resource. 811 812 Selects the best training settings for the given configuration. 813 The available configurations are listed in `CONFIGURATIONS`. 814 815 Args: 816 name: The name of the model to be trained. The checkpoint and logs folder will have this name. 817 train_loader: The dataloader for training. 818 val_loader: The dataloader for validation. 819 configuration: The configuration (= name of hardware resource). 820 By default, it is automatically selected for the best VRAM combination. 821 checkpoint_path: Path to checkpoint for initializing the SAM model. 822 with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. 823 By default, trains with the additional instance segmentation decoder. 824 train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation 825 using the training implementation `train_instance_segmentation`. By default, `train_sam` is used. 826 model_type: Over-ride the default model type. 827 This can be used to use one of the micro_sam models as starting point 828 instead of a default sam model. 829 kwargs: Additional keyword parameters that will be passed to `train_sam`. 830 """ 831 if configuration is None: # Automatically choose based on available VRAM combination. 832 configuration = _find_best_configuration() 833 834 if configuration in CONFIGURATIONS: 835 train_kwargs = CONFIGURATIONS[configuration] 836 else: 837 raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}") 838 839 if model_type is None: 840 model_type = train_kwargs.pop("model_type") 841 else: 842 expected_model_type = train_kwargs.pop("model_type") 843 if model_type[:5] != expected_model_type: 844 warnings.warn("You have specified a different model type.") 845 846 train_kwargs.update(**kwargs) 847 if train_instance_segmentation_only: 848 train_instance_segmentation( 849 name=name, 850 train_loader=train_loader, 851 val_loader=val_loader, 852 checkpoint_path=checkpoint_path, 853 with_segmentation_decoder=with_segmentation_decoder, 854 model_type=model_type, 855 **train_kwargs 856 ) 857 else: 858 train_sam( 859 name=name, 860 train_loader=train_loader, 861 val_loader=val_loader, 862 checkpoint_path=checkpoint_path, 863 with_segmentation_decoder=with_segmentation_decoder, 864 model_type=model_type, 865 **train_kwargs 866 )
Run training for a SAM model with the configuration for a given hardware resource.
Selects the best training settings for the given configuration.
The available configurations are listed in `CONFIGURATIONS`.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs folder will have this name.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- configuration: The configuration (= name of hardware resource). By default, it is automatically selected for the best VRAM combination.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
- train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
- model_type: Over-ride the default model type. This can be used to use one of the micro_sam models as a starting point instead of a default SAM model.
- kwargs: Additional keyword parameters that will be passed to `train_sam`.
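A usage sketch (placeholder paths and names; the loaders are built as in the `default_sam_loader` sketch above):

from micro_sam.training.training import default_sam_loader, train_sam_for_configuration

# Hypothetical data layout, see the default_sam_loader sketch above.
loader_kwargs = dict(
    raw_key="*.tif", label_key="*.tif", patch_shape=(512, 512),
    with_segmentation_decoder=True, batch_size=1, shuffle=True,
)
train_loader = default_sam_loader(raw_paths="data/train/images", label_paths="data/train/labels", **loader_kwargs)
val_loader = default_sam_loader(raw_paths="data/val/images", label_paths="data/val/labels", **loader_kwargs)

train_sam_for_configuration(
    name="sam_finetuned",
    train_loader=train_loader,
    val_loader=val_loader,
    configuration=None,               # None selects a preset based on the available VRAM; or pass a key from CONFIGURATIONS
    with_segmentation_decoder=True,
)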