micro_sam.training.training

   1import os
   2import time
   3import warnings
   4from glob import glob
   5from tqdm import tqdm
   6from collections import OrderedDict
   7from contextlib import contextmanager, nullcontext
   8from typing import Any, Callable, Dict, List, Optional, Tuple, Union
   9
  10import imageio.v3 as imageio
  11
  12import torch
  13from torch.optim import Optimizer
  14from torch.utils.data import random_split
  15from torch.utils.data import DataLoader, Dataset
  16from torch.optim.lr_scheduler import _LRScheduler
  17
  18import torch_em
  19from torch_em.util import load_data
  20from torch_em.data.datasets.util import split_kwargs
  21
  22from elf.io import open_file
  23
  24try:
  25    from qtpy.QtCore import QObject
  26except Exception:
  27    QObject = Any
  28
  29from . import sam_trainer as trainers
  30from ..instance_segmentation import get_unetr
  31from ..models.peft_sam import ClassicalSurgery
  32from . import joint_sam_trainer as joint_trainers
  33from ..util import get_device, get_model_names, export_custom_sam_model, get_sam_model
  34from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform
  35
  36
  37FilePath = Union[str, os.PathLike]
  38
  39
  40def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
  41    x, _ = next(iter(loader))
  42
  43    # Raw data: check that we have 1 or 3 channels.
  44    n_channels = x.shape[1]
  45    if n_channels not in (1, 3):
  46        raise ValueError(
  47            "Invalid number of channels for the input data from the data loader. "
  48            f"Expect 1 or 3 channels, got {n_channels}."
  49        )
  50
   51    # Raw data: check that the values are in the range [0, 255].
  52    minval, maxval = x.min(), x.max()
  53    if minval < 0 or minval > 255:
  54        raise ValueError(
  55            "Invalid data range for the input data from the data loader. "
  56            f"The input has to be in range [0, 255], but got minimum value {minval}."
  57        )
  58    if maxval < 1 or maxval > 255:
  59        raise ValueError(
  60            "Invalid data range for the input data from the data loader. "
  61            f"The input has to be in range [0, 255], but got maximum value {maxval}."
  62        )
  63
  64    # Target data: the check depends on whether we train with or without decoder.
   65    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).
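    # For reference (an illustrative sketch, not prescriptive shapes), a valid batch could look like:
    #   x: float tensor of shape (B, 1, 512, 512) or (B, 3, 512, 512) with values in [0, 255]
    #   y with decoder:    (B, 4, 512, 512) - the instance channel followed by 3 channels in [0.0, 1.0]
    #   y without decoder: (B, 1, 512, 512) - the instance channel only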
  66
  67    def _check_instance_channel(instance_channel):
  68        unique_vals = torch.unique(instance_channel)
  69        if (unique_vals < 0).any():
  70            raise ValueError(
  71                "The target channel with the instance segmentation must not have negative values."
  72            )
  73        if len(unique_vals) == 1:
  74            raise ValueError(
  75                "The target channel with the instance segmentation must have at least one instance."
  76            )
  77        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
  78            raise ValueError(
  79                "All values in the target channel with the instance segmentation must be integer."
  80            )
  81
  82    counter = 0
  83    name = "" if name is None else f"'{name}'"
  84    for x, y in tqdm(
  85        loader,
  86        desc=f"Verifying labels in {name} dataloader",
  87        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
  88    ):
  89        n_channels_y = y.shape[1]
  90        if with_segmentation_decoder:
  91            if n_channels_y != 4:
  92                raise ValueError(
  93                    "Invalid number of channels in the target data from the data loader. "
   94                    "Expect 4 channels for training with an instance segmentation decoder, "
  95                    f"but got {n_channels_y} channels."
  96                )
  97            # Check instance channel per sample in a batch
  98            for per_y_sample in y:
  99                _check_instance_channel(per_y_sample[0])
 100
 101            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
 102            if targets_min < 0 or targets_min > 1:
 103                raise ValueError(
  104                    "Invalid value range in the target data from the data loader. "
  105                    "Expect the last 3 target channels (for normalized distances and foreground probabilities) "
  106                    f"to be in range [0.0, 1.0], but got min {targets_min}."
 107                )
 108            if targets_max < 0 or targets_max > 1:
 109                raise ValueError(
  110                    "Invalid value range in the target data from the data loader. "
  111                    "Expect the last 3 target channels (for normalized distances and foreground probabilities) "
  112                    f"to be in range [0.0, 1.0], but got max {targets_max}."
 113                )
 114
 115        else:
 116            if n_channels_y != 1:
 117                raise ValueError(
 118                    "Invalid number of channels in the target data from the data loader. "
 119                    "Expect 1 channel for training without an instance segmentation decoder, "
 120                    f"but got {n_channels_y} channels."
 121                )
 122            # Check instance channel per sample in a batch
 123            for per_y_sample in y:
 124                _check_instance_channel(per_y_sample)
 125
 126        counter += 1
  127        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
 128            break
 129
 130
 131# Make the progress bar callbacks compatible with a tqdm progress bar interface.
 132class _ProgressBarWrapper:
 133    def __init__(self, signals):
 134        self._signals = signals
 135        self._total = None
 136
 137    @property
 138    def total(self):
 139        return self._total
 140
 141    @total.setter
 142    def total(self, value):
 143        self._signals.pbar_total.emit(value)
 144        self._total = value
 145
 146    def update(self, steps):
 147        self._signals.pbar_update.emit(steps)
 148
 149    def set_description(self, desc, **kwargs):
 150        self._signals.pbar_description.emit(desc)
 151
 152
 153@contextmanager
 154def _filter_warnings(ignore_warnings):
 155    if ignore_warnings:
 156        with warnings.catch_warnings():
 157            warnings.simplefilter("ignore")
 158            yield
 159    else:
 160        with nullcontext():
 161            yield
 162
 163
 164def _count_parameters(model_parameters):
 165    params = sum(p.numel() for p in model_parameters if p.requires_grad)
 166    params = params / 1e6
  167    print(f"The number of trainable parameters for the provided model is {params} million (~{round(params, 2)}M).")
 168
 169
 170def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training):
 171    if n_iterations is None:
 172        trainer_fit_params = {"epochs": n_epochs}
 173    else:
 174        trainer_fit_params = {"iterations": n_iterations}
 175
 176    if save_every_kth_epoch is not None:
 177        trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch
 178
 179    if pbar_signals is not None:
 180        progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
 181        trainer_fit_params["progress"] = progress_bar_wrapper
 182
 183    # Avoid overwriting a trained model, if desired by the user.
 184    trainer_fit_params["overwrite_training"] = overwrite_training
 185    return trainer_fit_params
 186
 187
 188def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs):
 189    optimizer = optimizer_class(model_params, lr=lr)
 190    if scheduler_kwargs is None:
 191        scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3}
 192    scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)
 193    return optimizer, scheduler
 194
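# For example, the scheduler defaults above could be overridden when calling 'train_sam' via
# scheduler_kwargs={"mode": "min", "factor": 0.5, "patience": 5} (illustrative values, not a recommendation).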
 195
 196def train_sam(
 197    name: str,
 198    model_type: str,
 199    train_loader: DataLoader,
 200    val_loader: DataLoader,
 201    n_epochs: int = 100,
 202    early_stopping: Optional[int] = 10,
 203    n_objects_per_batch: Optional[int] = 25,
 204    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 205    with_segmentation_decoder: bool = True,
 206    freeze: Optional[List[str]] = None,
 207    device: Optional[Union[str, torch.device]] = None,
 208    lr: float = 1e-5,
 209    n_sub_iteration: int = 8,
 210    save_root: Optional[Union[str, os.PathLike]] = None,
 211    mask_prob: float = 0.5,
 212    n_iterations: Optional[int] = None,
 213    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
 214    scheduler_kwargs: Optional[Dict[str, Any]] = None,
 215    save_every_kth_epoch: Optional[int] = None,
 216    pbar_signals: Optional[QObject] = None,
 217    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
 218    peft_kwargs: Optional[Dict] = None,
 219    ignore_warnings: bool = True,
 220    verify_n_labels_in_loader: Optional[int] = 50,
 221    box_distortion_factor: Optional[float] = 0.025,
 222    overwrite_training: bool = True,
 223    **model_kwargs,
 224) -> None:
 225    """Run training for a SAM model.
 226
 227    Args:
 228        name: The name of the model to be trained. The checkpoint and logs will have this name.
 229        model_type: The type of the SAM model.
 230        train_loader: The dataloader for training.
 231        val_loader: The dataloader for validation.
 232        n_epochs: The number of epochs to train for.
 233        early_stopping: Enable early stopping after this number of epochs without improvement.
 234            By default, the value is set to '10' epochs.
 235        n_objects_per_batch: The number of objects per batch used to compute
  236            the loss for interactive segmentation. If None, all objects will be used,
  237            if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
 238        checkpoint_path: Path to checkpoint for initializing the SAM model.
 239        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
 240            By default, trains with the additional instance segmentation decoder.
  241        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
  242            By default, nothing is frozen and the full model is updated.
 243        device: The device to use for training. By default, automatically chooses the best available device to train.
 244        lr: The learning rate. By default, set to '1e-5'.
 245        n_sub_iteration: The number of iterative prompts per training iteration.
 246            By default, the number of iterations is set to '8'.
 247        save_root: Optional root directory for saving the checkpoints and logs.
 248            If not given the current working directory is used.
 249        mask_prob: The probability for using a mask as input in a given training sub-iteration.
 250            By default, set to '0.5'.
 251        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
 252        scheduler_class: The learning rate scheduler to update the learning rate.
 253            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
 254        scheduler_kwargs: The learning rate scheduler parameters.
 255            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
 256        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
 257        pbar_signals: Controls for napari progress bar.
 258        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
 259        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 260        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
 261        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
 262            By default, 50 batches of labels are verified from the dataloaders.
 263        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
 264            By default, the distortion factor is set to '0.025'.
 265        overwrite_training: Whether to overwrite the trained model stored at the same location.
 266            By default, overwrites the trained model at each run.
 267            If set to 'False', it will avoid retraining the model if the previous run was completed.
 268        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
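
    Example (a minimal sketch; the folder paths are placeholders and assume 2D tif images with instance labels):

        from micro_sam.training import train_sam, default_sam_loader

        train_loader = default_sam_loader(
            raw_paths="data/train/images", raw_key="*.tif",
            label_paths="data/train/labels", label_key="*.tif",
            patch_shape=(512, 512), with_segmentation_decoder=True, batch_size=1, shuffle=True,
        )
        val_loader = default_sam_loader(
            raw_paths="data/val/images", raw_key="*.tif",
            label_paths="data/val/labels", label_key="*.tif",
            patch_shape=(512, 512), with_segmentation_decoder=True, batch_size=1, shuffle=True,
        )
        train_sam(
            name="sam_finetuned", model_type="vit_b",
            train_loader=train_loader, val_loader=val_loader, n_epochs=50,
        )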
 269    """
 270    with _filter_warnings(ignore_warnings):
 271
 272        t_start = time.time()
 273
 274        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
 275        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
 276
 277        device = get_device(device)
 278        # Get the trainable segment anything model.
 279        model, state = get_trainable_sam_model(
 280            model_type=model_type,
 281            device=device,
 282            freeze=freeze,
 283            checkpoint_path=checkpoint_path,
 284            return_state=True,
 285            peft_kwargs=peft_kwargs,
 286            **model_kwargs
 287        )
 288
 289        # This class creates all the training data for a batch (inputs, prompts and labels).
 290        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)
 291
 292        # Create the UNETR decoder (if train with it) and the optimizer.
 293        if with_segmentation_decoder:
 294
 295            # Get the UNETR.
 296            unetr = get_unetr(
 297                image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
 298            )
 299
 300            # Get the parameters for SAM and the decoder from UNETR.
 301            joint_model_params = [params for params in model.parameters()]  # sam parameters
 302            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
 303                if not param_name.startswith("encoder"):
 304                    joint_model_params.append(params)
 305
 306            model_params = joint_model_params
 307        else:
 308            model_params = model.parameters()
 309
 310        optimizer, scheduler = _get_optimizer_and_scheduler(
 311            model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs
 312        )
 313
 314        # The trainer which performs training and validation.
 315        if with_segmentation_decoder:
 316            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
 317            trainer = joint_trainers.JointSamTrainer(
 318                name=name,
 319                save_root=save_root,
 320                train_loader=train_loader,
 321                val_loader=val_loader,
 322                model=model,
 323                optimizer=optimizer,
 324                device=device,
 325                lr_scheduler=scheduler,
 326                logger=joint_trainers.JointSamLogger,
 327                log_image_interval=100,
 328                mixed_precision=True,
 329                convert_inputs=convert_inputs,
 330                n_objects_per_batch=n_objects_per_batch,
 331                n_sub_iteration=n_sub_iteration,
 332                compile_model=False,
 333                unetr=unetr,
 334                instance_loss=instance_seg_loss,
 335                instance_metric=instance_seg_loss,
 336                early_stopping=early_stopping,
 337                mask_prob=mask_prob,
 338            )
 339        else:
 340            trainer = trainers.SamTrainer(
 341                name=name,
 342                train_loader=train_loader,
 343                val_loader=val_loader,
 344                model=model,
 345                optimizer=optimizer,
 346                device=device,
 347                lr_scheduler=scheduler,
 348                logger=trainers.SamLogger,
 349                log_image_interval=100,
 350                mixed_precision=True,
 351                convert_inputs=convert_inputs,
 352                n_objects_per_batch=n_objects_per_batch,
 353                n_sub_iteration=n_sub_iteration,
 354                compile_model=False,
 355                early_stopping=early_stopping,
 356                mask_prob=mask_prob,
 357                save_root=save_root,
 358            )
 359
 360        trainer_fit_params = _get_trainer_fit_params(
 361            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
 362        )
 363        trainer.fit(**trainer_fit_params)
 364
 365        t_run = time.time() - t_start
 366        hours = int(t_run // 3600)
  367        minutes = int((t_run % 3600) // 60)
 368        seconds = int(round(t_run % 60, 0))
 369        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
 370
 371
 372def export_instance_segmentation_model(
 373    trained_model_path: Union[str, os.PathLike],
 374    output_path: Union[str, os.PathLike],
 375    model_type: str,
 376    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 377) -> None:
 378    """Export a model trained for instance segmentation with `train_instance_segmentation`.
 379
 380    The exported model will be compatible with the micro_sam functions, CLI and napari plugin.
 381    It should only be used for automatic segmentation and may not work well for interactive segmentation.
 382
 383    Args:
 384        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
 385        output_path: The path where the exported model will be saved.
 386        model_type: The model type.
  387        initial_checkpoint_path: The initial checkpoint path that the instance segmentation training was based on (optional).
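
    Example (a minimal sketch; the checkpoint path follows the 'checkpoints/<name>/best.pt' layout used during training):

        export_instance_segmentation_model(
            trained_model_path="checkpoints/my_instance_model/best.pt",
            output_path="my_instance_model_export.pt",
            model_type="vit_b",
        )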
 388    """
 389    trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"]
 390
 391    # Get the state of the encoder and instance segmentation decoder from the trained checkpoint.
 392    encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")])
 393    decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")])
 394
 395    # Load the original state of the model that was used as the basis of instance segmentation training.
 396    _, model_state = get_sam_model(
 397        model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu",
 398    )
 399    # Remove the sam prefix if it's in the model state.
 400    prefix = "sam."
 401    model_state = OrderedDict(
 402        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
 403    )
 404
  405    # Replace the image encoder state; stripping "image_" from "image_encoder.*" keys yields the "encoder.*" keys.
 406    model_state = OrderedDict(
 407        [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v)
 408         for k, v in model_state.items()]
 409    )
 410
 411    save_state = {"model_state": model_state, "decoder_state": decoder_state}
 412    torch.save(save_state, output_path)
 413
 414
 415def train_instance_segmentation(
 416    name: str,
 417    model_type: str,
 418    train_loader: DataLoader,
 419    val_loader: DataLoader,
 420    n_epochs: int = 100,
 421    early_stopping: Optional[int] = 10,
 422    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
 423    metric: Optional[torch.nn.Module] = None,
 424    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 425    freeze: Optional[List[str]] = None,
 426    device: Optional[Union[str, torch.device]] = None,
 427    lr: float = 1e-5,
 428    save_root: Optional[Union[str, os.PathLike]] = None,
 429    n_iterations: Optional[int] = None,
 430    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
 431    scheduler_kwargs: Optional[Dict[str, Any]] = None,
 432    save_every_kth_epoch: Optional[int] = None,
 433    pbar_signals: Optional[QObject] = None,
 434    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
 435    peft_kwargs: Optional[Dict] = None,
 436    ignore_warnings: bool = True,
 437    overwrite_training: bool = True,
 438    **model_kwargs,
 439) -> None:
 440    """Train a UNETR for instance segmentation using the SAM encoder as backbone.
 441
 442    This setting corresponds to training a SAM model with an instance segmentation decoder,
 443    without training the model parts for interactive segmentation,
 444    i.e. without training the prompt encoder and mask decoder.
 445
 446    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
 447    will not be compatible with the micro_sam functionality.
 448    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
 449    in a format that is compatible with micro_sam functionality.
 450    Note that the exported model should only be used for automatic segmentation via AIS.
 451
 452    Args:
 453        name: The name of the model to be trained. The checkpoint and logs will have this name.
 454        model_type: The type of the SAM model.
 455        train_loader: The dataloader for training.
 456        val_loader: The dataloader for validation.
 457        n_epochs: The number of epochs to train for.
 458        early_stopping: Enable early stopping after this number of epochs without improvement.
 459            By default, the value is set to '10' epochs.
 460        loss: The loss function to train the instance segmentation model.
  461            By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
  462        metric: The metric for the instance segmentation training.
  463            By default, the loss function is used as the metric.
 464        checkpoint_path: Path to checkpoint for initializing the SAM model.
 465        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
 466            By default nothing is frozen and the full model is updated.
 467        device: The device to use for training. By default, automatically chooses the best available device to train.
 468        lr: The learning rate. By default, set to '1e-5'.
 469        save_root: Optional root directory for saving the checkpoints and logs.
 470            If not given the current working directory is used.
 471        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
 472        scheduler_class: The learning rate scheduler to update the learning rate.
 473            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
 474        scheduler_kwargs: The learning rate scheduler parameters.
 475            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
 476        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
 477        pbar_signals: Controls for napari progress bar.
 478        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
 479        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 480        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
 481        overwrite_training: Whether to overwrite the trained model stored at the same location.
 482            By default, overwrites the trained model at each run.
 483            If set to 'False', it will avoid retraining the model if the previous run was completed.
 484        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
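
    Example (a minimal sketch; the loaders are placeholders created with `default_sam_loader` and
    `train_instance_segmentation_only=True`):

        train_instance_segmentation(
            name="my_instance_model", model_type="vit_b",
            train_loader=train_loader, val_loader=val_loader, n_epochs=50,
        )
        export_instance_segmentation_model(
            trained_model_path="checkpoints/my_instance_model/best.pt",
            output_path="my_instance_model_export.pt", model_type="vit_b",
        )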
 485    """
 486
 487    with _filter_warnings(ignore_warnings):
 488        t_start = time.time()
 489
 490        sam_model, state = get_trainable_sam_model(
 491            model_type=model_type,
 492            device=device,
 493            checkpoint_path=checkpoint_path,
 494            return_state=True,
 495            peft_kwargs=peft_kwargs,
 496            freeze=freeze,
 497            **model_kwargs
 498        )
 499        device = get_device(device)
 500        model = get_unetr(
 501            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
 502        )
 503
 504        optimizer, scheduler = _get_optimizer_and_scheduler(
 505            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
 506        )
 507        trainer = torch_em.trainer.DefaultTrainer(
 508            name=name,
 509            model=model,
 510            train_loader=train_loader,
 511            val_loader=val_loader,
 512            device=device,
 513            mixed_precision=True,
 514            log_image_interval=50,
 515            compile_model=False,
 516            save_root=save_root,
 517            loss=loss,
 518            metric=loss if metric is None else metric,
 519            optimizer=optimizer,
 520            lr_scheduler=scheduler,
 521            early_stopping=early_stopping,
 522        )
 523
 524        trainer_fit_params = _get_trainer_fit_params(
 525            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
 526        )
 527        trainer.fit(**trainer_fit_params)
 528
 529        t_run = time.time() - t_start
 530        hours = int(t_run // 3600)
  531        minutes = int((t_run % 3600) // 60)
 532        seconds = int(round(t_run % 60, 0))
 533        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
 534
 535
 536def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
 537    if isinstance(raw_paths, (str, os.PathLike)):
 538        path = raw_paths
 539    else:
 540        path = raw_paths[0]
 541    assert isinstance(path, (str, os.PathLike))
 542
 543    # Check the underlying data dimensionality.
 544    if raw_key is None:  # If no key is given then we assume it's an image file.
 545        ndim = imageio.imread(path).ndim
 546    else:  # Otherwise we try to open the file from key.
 547        try:  # First try to open it with elf.
 548            with open_file(path, "r") as f:
 549                ndim = f[raw_key].ndim
 550        except ValueError:  # This may fail for images in a folder with different sizes.
 551            # In that case we read one of the images.
 552            image_path = glob(os.path.join(path, raw_key))[0]
 553            ndim = imageio.imread(image_path).ndim
 554
 555    if not isinstance(patch_shape, tuple):
 556        patch_shape = tuple(patch_shape)
 557
 558    if ndim == 2:
 559        assert len(patch_shape) == 2
 560        return patch_shape
 561    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
 562        return (1,) + patch_shape
 563    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
 564        return (1,) + patch_shape
 565    else:
 566        return patch_shape
 567
 568
 569def default_sam_dataset(
 570    raw_paths: Union[List[FilePath], FilePath],
 571    raw_key: Optional[str],
 572    label_paths: Union[List[FilePath], FilePath],
 573    label_key: Optional[str],
  574    patch_shape: Tuple[int, ...],
 575    with_segmentation_decoder: bool,
 576    with_channels: Optional[bool] = None,
 577    train_instance_segmentation_only: bool = False,
 578    sampler: Optional[Callable] = None,
 579    raw_transform: Optional[Callable] = None,
 580    n_samples: Optional[int] = None,
 581    is_train: bool = True,
 582    min_size: int = 25,
 583    max_sampling_attempts: Optional[int] = None,
 584    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
 585    **kwargs,
 586) -> Dataset:
 587    """Create a PyTorch Dataset for training a SAM model.
 588
 589    Args:
 590        raw_paths: The path(s) to the image data used for training.
 591            Can either be multiple 2D images or volumetric data.
 592        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
 593            or a glob pattern for selecting multiple files.
 594        label_paths: The path(s) to the label data used for training.
 595            Can either be multiple 2D images or volumetric data.
 596        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
 597            or a glob pattern for selecting multiple files.
 598        patch_shape: The shape for training patches.
 599        with_segmentation_decoder: Whether to train with additional segmentation decoder.
  600        with_channels: Whether the image data has channels. By default, this is derived from the input data.
 601        train_instance_segmentation_only: Set this argument to True in order to
 602            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
 603        sampler: A sampler to reject batches according to a given criterion.
 604        raw_transform: Transformation applied to the image data.
 605            If not given the data will be cast to 8bit.
 606        n_samples: The number of samples for this dataset.
 607        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
 608        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
 609        max_sampling_attempts: Number of sampling attempts to make from a dataset.
 610        rois: The region of interest(s) for the data.
 611        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 612
 613    Returns:
 614        The segmentation dataset.
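
    Example (a minimal sketch; the folder paths are placeholders and assume 2D tif images with instance labels):

        dataset = default_sam_dataset(
            raw_paths="data/images", raw_key="*.tif",
            label_paths="data/labels", label_key="*.tif",
            patch_shape=(512, 512), with_segmentation_decoder=True,
        )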
 615    """
 616
 617    # Check if this dataset should be used for instance segmentation only training.
 618    # If yes, we set return_instances to False, since the instance channel must not
 619    # be passed for this training mode.
 620    return_instances = True
 621    if train_instance_segmentation_only:
 622        if not with_segmentation_decoder:
 623            raise ValueError(
 624                "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
 625            )
 626        return_instances = False
 627
  628    # If a sampler is not passed, then we set a MinInstanceSampler, which rejects samples without enough instances.
 629    # This is necessary, because training for interactive segmentation does not work on 'empty' images.
 630    # However, if we train only the automatic instance segmentation decoder, then this sampler is not required
 631    # and we do not set a default sampler.
 632    if sampler is None and not train_instance_segmentation_only:
 633        sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)
 634
 635    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
 636    is_seg_dataset = kwargs.pop("is_seg_dataset", None)
 637
 638    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
 639    # Get valid raw paths to make checks possible.
 640    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
 641        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
  642    else:  # Otherwise, 'raw_key' is either None or an internal path for a container format supported by 'elf'.
 643        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
 644
 645    # Load one of the raw inputs to validate whether it is RGB or not.
 646    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
 647    if test_raw_inputs.ndim == 3:
 648        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
 649            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
 650            # We need to provide a list of inputs to 'ImageCollectionDataset'.
 651            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
 652            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
 653
 654            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
 655            with_channels = False if with_channels is None else with_channels
 656
  657        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is an RGB image and has 3 channels first.
 658            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
 659            with_channels = True if with_channels is None else with_channels
 660
  661    # If 'with_channels' was neither set by the user nor derived above, fall back to 'False',
  662    # i.e. the default behavior of 'default_segmentation_dataset'.
 663    with_channels = False if with_channels is None else with_channels
 664
 665    # Set the data transformations.
 666    if raw_transform is None:
 667        raw_transform = require_8bit
 668
 669    # Prepare the label transform.
 670    if with_segmentation_decoder:
 671        default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
 672            distances=True,
 673            boundary_distances=True,
 674            directed_distances=False,
 675            foreground=True,
 676            instances=return_instances,
 677            min_size=min_size,
 678        )
 679    else:
 680        default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
 681
 682    # Allow combining label transforms.
 683    custom_label_transform = kwargs.pop("label_transform", None)
 684    if custom_label_transform is None:
 685        label_transform = default_label_transform
 686    else:
 687        label_transform = torch_em.transform.generic.Compose(custom_label_transform, default_label_transform)
 688
 689    # Check the patch shape to add a singleton if required.
 690    patch_shape = _update_patch_shape(
 691        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
 692    )
 693
 694    # Set a minimum number of samples per epoch.
 695    if n_samples is None:
 696        loader = torch_em.default_segmentation_loader(
 697            raw_paths=raw_paths,
 698            raw_key=raw_key,
 699            label_paths=label_paths,
 700            label_key=label_key,
 701            batch_size=1,
 702            patch_shape=patch_shape,
 703            with_channels=with_channels,
 704            ndim=2,
 705            is_seg_dataset=is_seg_dataset,
 706            raw_transform=raw_transform,
 707            rois=rois,
 708            **kwargs
 709        )
 710        n_samples = max(len(loader), 100 if is_train else 5)
 711
 712    dataset = torch_em.default_segmentation_dataset(
 713        raw_paths=raw_paths,
 714        raw_key=raw_key,
 715        label_paths=label_paths,
 716        label_key=label_key,
 717        patch_shape=patch_shape,
 718        raw_transform=raw_transform,
 719        label_transform=label_transform,
 720        with_channels=with_channels,
 721        ndim=2,
 722        sampler=sampler,
 723        n_samples=n_samples,
 724        is_seg_dataset=is_seg_dataset,
 725        rois=rois,
 726        **kwargs,
 727    )
 728
 729    if max_sampling_attempts is not None:
 730        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
 731            for ds in dataset.datasets:
 732                ds.max_sampling_attempts = max_sampling_attempts
 733        else:
 734            dataset.max_sampling_attempts = max_sampling_attempts
 735
 736    return dataset
 737
 738
 739def default_sam_loader(**kwargs) -> DataLoader:
 740    """Create a PyTorch DataLoader for training a SAM model.
 741
 742    Args:
 743        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
 744
 745    Returns:
 746        The DataLoader.
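
    Example (a minimal sketch; dataset and DataLoader keyword arguments can be mixed, they are split internally):

        loader = default_sam_loader(
            raw_paths="data/images", raw_key="*.tif",
            label_paths="data/labels", label_key="*.tif",
            patch_shape=(512, 512), with_segmentation_decoder=True,
            batch_size=2, num_workers=4, shuffle=True,
        )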
 747    """
 748    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)
 749
 750    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
 751    # which the users can provide to get their desired segmentation dataset.
 752    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
 753    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}
 754
 755    ds = default_sam_dataset(**ds_kwargs)
 756    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
 757
 758
 759CONFIGURATIONS = {
 760    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
 761    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
 762    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
 763    "gtx3080": {
 764        "model_type": "vit_b", "n_objects_per_batch": 5,
 765        "peft_kwargs": {"attention_layers_to_update": [11], "peft_module": ClassicalSurgery}
 766    },
 767    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
 768    "V100": {"model_type": "vit_b"},
 769    "A100": {"model_type": "vit_h"},
 770}
 771"""Best training configurations for given hardware resources.
 772"""
 773
 774
 775def _find_best_configuration():
 776    if torch.cuda.is_available():
 777
 778        # Check how much memory we have and select the best matching GPU
 779        # for the available VRAM size.
 780        _, vram = torch.cuda.mem_get_info()
 781        vram = vram / 1e9  # in GB
 782
 783        # Maybe we can get more configurations in the future.
 784        if vram > 80:  # More than 80 GB: use the A100 configurations.
 785            return "A100"
 786        elif vram > 30:  # More than 30 GB: use the V100 configurations.
 787            return "V100"
 788        elif vram > 14:  # More than 14 GB: use the RTX5000 configurations.
 789            return "rtx5000"
 790        elif vram > 8:  # More than 8 GB: use the GTX3080 configurations.
 791            return "gtx3080"
 792        else:  # Otherwise: not enough memory to train on the GPU, use CPU instead.
 793            return "CPU"
 794    else:
 795        return "CPU"
 796
 797
 798def train_sam_for_configuration(
 799    name: str,
 800    train_loader: DataLoader,
 801    val_loader: DataLoader,
 802    configuration: Optional[str] = None,
 803    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 804    with_segmentation_decoder: bool = True,
 805    train_instance_segmentation_only: bool = False,
 806    model_type: Optional[str] = None,
 807    **kwargs,
 808) -> None:
 809    """Run training for a SAM model with the configuration for a given hardware resource.
 810
 811    Selects the best training settings for the given configuration.
 812    The available configurations are listed in `CONFIGURATIONS`.
 813
 814    Args:
 815        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
 816        train_loader: The dataloader for training.
 817        val_loader: The dataloader for validation.
 818        configuration: The configuration (= name of hardware resource).
  819            By default, it is automatically selected based on the available VRAM.
 820        checkpoint_path: Path to checkpoint for initializing the SAM model.
 821        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
 822            By default, trains with the additional instance segmentation decoder.
 823        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
 824            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
 825        model_type: Over-ride the default model type.
 826            This can be used to use one of the micro_sam models as starting point
 827            instead of a default sam model.
 828        kwargs: Additional keyword parameters that will be passed to `train_sam`.
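
    Example (a minimal sketch; the loaders are placeholders created with `default_sam_loader`):

        train_sam_for_configuration(
            name="sam_finetuned",
            train_loader=train_loader, val_loader=val_loader,
            configuration="rtx5000",  # Or None, to auto-select based on the available VRAM.
        )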
 829    """
  830    if configuration is None:  # Automatically choose based on the available VRAM.
 831        configuration = _find_best_configuration()
 832
 833    if configuration in CONFIGURATIONS:
  834        train_kwargs = CONFIGURATIONS[configuration].copy()  # Copy, to avoid mutating the global configuration.
 835    else:
  836        raise ValueError(f"Invalid configuration '{configuration}', expect one of {list(CONFIGURATIONS.keys())}.")
 837
 838    if model_type is None:
 839        model_type = train_kwargs.pop("model_type")
 840    else:
 841        expected_model_type = train_kwargs.pop("model_type")
 842        if model_type[:5] != expected_model_type:
  843            warnings.warn(f"The model type '{model_type}' differs from the configuration default '{expected_model_type}'.")
 844
 845    train_kwargs.update(**kwargs)
 846    if train_instance_segmentation_only:
 847        train_instance_segmentation(
 848            name=name,
 849            train_loader=train_loader,
 850            val_loader=val_loader,
 851            checkpoint_path=checkpoint_path,
 853            model_type=model_type,
 854            **train_kwargs
 855        )
 856    else:
 857        train_sam(
 858            name=name,
 859            train_loader=train_loader,
 860            val_loader=val_loader,
 861            checkpoint_path=checkpoint_path,
 862            with_segmentation_decoder=with_segmentation_decoder,
 863            model_type=model_type,
 864            **train_kwargs
 865        )
 866
 867
 868def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):
 869
 870    # Whether the model is stored in the current working directory or in another location.
 871    if save_root is None:
 872        save_root = os.getcwd()  # Map this to current working directory, if not specified by the user.
 873
 874    # Get the 'best' model checkpoint ready for export.
 875    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
 876    if not os.path.exists(best_checkpoint):
  877        raise FileNotFoundError(f"The trained model was not found at the expected location: '{best_checkpoint}'.")
 878
 879    # Export the model if an output path has been given.
 880    if output_path:
 881
 882        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
 883        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
 884            export_custom_sam_model(
 885                checkpoint_path=best_checkpoint,
 886                model_type=model_type[:5],
 887                save_path=output_path,
 888                with_segmentation_decoder=with_segmentation_decoder,
 889            )
 890
 891        # Otherwise we export it as bioimage.io model.
 892        else:
 893            from micro_sam.bioimageio import export_sam_model
 894
 895            # Load image and corresponding labels from the val loader.
 896            with torch.no_grad():
 897                image_data, label_data = next(iter(val_loader))
 898                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()
 899
  900                # If the labels have a channel axis, select the first channel, which contains the instance labels.
 901                if label_data.ndim == 3:
 902                    label_data = label_data[0]  # Gets the channel with instances.
 903                assert image_data.shape == label_data.shape
 904                label_data = label_data.astype("uint32")
 905
 906                export_sam_model(
 907                    image=image_data,
 908                    label_image=label_data,
 909                    model_type=model_type[:5],
 910                    name=checkpoint_name,
 911                    output_path=output_path,
 912                    checkpoint_path=best_checkpoint,
 913                )
 914
 915        # The final path where the model has been stored.
 916        final_path = output_path
 917
 918    else:  # If no exports have been made, inform the user about the best checkpoint.
 919        final_path = best_checkpoint
 920
 921    return final_path
 922
 923
 924def _parse_segmentation_decoder(segmentation_decoder):
 925    if segmentation_decoder in ("None", "none"):
 926        with_segmentation_decoder, train_instance_segmentation_only = False, False
 927    elif segmentation_decoder == "instances":
 928        with_segmentation_decoder, train_instance_segmentation_only = True, False
 929    elif segmentation_decoder == "instances_only":
 930        with_segmentation_decoder, train_instance_segmentation_only = True, True
 931    else:
 932        raise ValueError(
 933            "The 'segmentation_decoder' argument currently supports the values:\n"
 934            f"'instances', 'instances_only', or 'None'. You have passed {segmentation_decoder}."
 935        )
 936    return with_segmentation_decoder, train_instance_segmentation_only
 937
 938
 939def main():
 940    """@private"""
 941    import argparse
 942
 943    available_models = list(get_model_names())
 944    available_models = ", ".join(available_models)
 945
 946    available_configurations = list(CONFIGURATIONS.keys())
 947    available_configurations = ", ".join(available_configurations)
 948
 949    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")
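    # Example invocation (a sketch; paths are placeholders and we assume the CLI entry point 'micro_sam.train'
    # exposed by the package, otherwise call this 'main' function directly):
    # $ micro_sam.train --images data/images --labels data/labels --image_key "*.tif" --label_key "*.tif" \
    #       --model_type vit_b --n_epochs 50 --output_path finetuned_model.pt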
 950
 951    # Images and labels for training.
 952    parser.add_argument(
 953        "--images", required=True, type=str, nargs="*",
 954        help="Filepath to images or the directory where the image data is stored."
 955    )
 956    parser.add_argument(
 957        "--labels", required=True, type=str, nargs="*",
 958        help="Filepath to ground-truth labels or the directory where the label data is stored."
 959    )
 960    parser.add_argument(
 961        "--image_key", type=str, default=None,
  962        help="The key for accessing image data, either a pattern / wildcard or an internal path for elf.io.open_file."
 963    )
 964    parser.add_argument(
 965        "--label_key", type=str, default=None,
  966        help="The key for accessing label data, either a pattern / wildcard or an internal path for elf.io.open_file."
 967    )
 968
 969    # Images and labels for validation.
 970    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
 971    # Users can choose to have their explicit validation set via this feature as well.
 972    parser.add_argument(
 973        "--val_images", type=str, nargs="*",
 974        help="Filepath to images for validation or the directory where the image data is stored."
 975    )
 976    parser.add_argument(
 977        "--val_labels", type=str, nargs="*",
 978        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
 979    )
 980    parser.add_argument(
 981        "--val_image_key", type=str, default=None,
 982        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
 983    )
 984    parser.add_argument(
 985        "--val_label_key", type=str, default=None,
 986        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
 987    )
 988
 989    # Other necessary stuff for training.
 990    parser.add_argument(
 991        "--configuration", type=str, default=_find_best_configuration(),
 992        help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations}."
 993    )
 994
 995    def none_or_str(value):
 996        if value.lower() == 'none':
 997            return None
 998        return value
 999
1000    # This could be extended to train for semantic segmentation or other options.
1001    parser.add_argument(
1002        "--segmentation_decoder", type=none_or_str, default="instances",
1003        help="Whether to finetune Segment Anything Model with an additional segmentation decoder. "
1004        "The following options are possible:\n"
1005        "- 'instances' to train with an additional decoder for automatic instance segmentation. "
1006        "  This option enables using the automatic instance segmentation (AIS) mode.\n"
1007        "- 'instances_only' to train only the instance segmentation decoder. "
1008        "  In this case the parts of SAM that are used for interactive segmentation will not be trained.\n"
1009        "- 'None' to train without an additional segmentation decoder."
 1010        "  This option trains only the parts of the original SAM.\n"
1011        "By default the option 'instances' is used."
1012    )
1013
1014    # Optional advanced settings a user can opt to change the values for.
1015    parser.add_argument(
1016        "-d", "--device", type=str, default=None,
1017        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
1018        "By default the most performant available device will be selected."
1019    )
1020    parser.add_argument(
1021        "--patch_shape", type=int, nargs="*", default=(512, 512),
1022        help="The choice of patch shape for training Segment Anything Model. "
1023        "By default, a patch size of 512x512 is used."
1024    )
1025    parser.add_argument(
1026        "-m", "--model_type", type=str, default=None,
1027        help=f"The Segment Anything Model that will be used for finetuning, one of {available_models}."
1028    )
1029    parser.add_argument(
1030        "--checkpoint_path", type=str, default=None,
1031        help="Checkpoint from which the SAM model will be loaded for finetuning."
1032    )
1033    parser.add_argument(
1034        "-s", "--save_root", type=str, default=None,
1035        help="The directory where the trained models and corresponding logs will be stored. "
 1036        "By default, they are stored in your current working directory."
1037    )
1038    parser.add_argument(
1039        "--trained_model_name", type=str, default="sam_model",
1040        help="The custom name of trained model sub-folder. Allows users to have several trained models "
1041        "under the same 'save_root'."
1042    )
1043    parser.add_argument(
1044        "--output_path", type=str, default=None,
 1045        help="The directory (e.g. '/path/to/folder') or filepath (e.g. '/path/to/model.pt') to export the trained model."
1046    )
1047    parser.add_argument(
1048        "--n_epochs", type=int, default=100,
1049        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
1050    )
1051    parser.add_argument(
1052        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
1053    )
1054    parser.add_argument(
1055        "--batch_size", type=int, default=1,
1056        help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1."
1057    )
1058    parser.add_argument(
1059        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
 1060        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images. "
 1061        "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
1062    )
1063
1064    args = parser.parse_args()
1065
1066    # 1. Get all necessary stuff for training.
1067    checkpoint_name = args.trained_model_name
1068    config = args.configuration
1069    model_type = args.model_type
1070    checkpoint_path = args.checkpoint_path
1071    batch_size = args.batch_size
1072    patch_shape = args.patch_shape
1073    epochs = args.n_epochs
1074    num_workers = args.num_workers
1075    device = args.device
1076    save_root = args.save_root
1077    output_path = args.output_path
1078    with_segmentation_decoder, train_instance_segmentation_only = _parse_segmentation_decoder(args.segmentation_decoder)
1079
1080    # Get image paths and corresponding keys.
1081    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
1082    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa
1083
1084    # 2. Prepare the dataloaders.
1085
1086    # If the user wants to preprocess the inputs, we allow the possibility to do so.
1087    _raw_transform = get_raw_transform(args.preprocess)
1088
1089    # Get the dataset with files for training.
1090    dataset = default_sam_dataset(
1091        raw_paths=train_images,
1092        raw_key=train_image_key,
1093        label_paths=train_gt,
1094        label_key=train_gt_key,
1095        patch_shape=patch_shape,
1096        with_segmentation_decoder=with_segmentation_decoder,
1097        raw_transform=_raw_transform,
1098        train_instance_segmentation_only=train_instance_segmentation_only,
1099    )
1100
1101    # If val images are not exclusively provided, we create a val split from the training data.
1102    if val_images is None:
1103        assert val_gt is None and val_image_key is None and val_gt_key is None
 1104        # Use 10% of the dataset - at least one image - for validation.
1105        n_val = max(1, int(0.1 * len(dataset)))
1106        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])
1107
1108    else:  # If val images provided, we create a new dataset for it.
1109        train_dataset = dataset
1110        val_dataset = default_sam_dataset(
1111            raw_paths=val_images,
1112            raw_key=val_image_key,
1113            label_paths=val_gt,
1114            label_key=val_gt_key,
1115            patch_shape=patch_shape,
1116            with_segmentation_decoder=with_segmentation_decoder,
1117            train_instance_segmentation_only=train_instance_segmentation_only,
1118            raw_transform=_raw_transform,
1119        )
1120
1121    # Get the dataloaders from the datasets.
1122    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
1123    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)
1124
1125    # 3. Train the Segment Anything Model.
1126
1127    # Get a valid model and other necessary parameters for training.
 1128    if model_type is not None and model_type not in get_model_names():
 1129        raise ValueError(f"'{model_type}' is not a valid choice of model.")
 1130    if config is not None and config not in CONFIGURATIONS:
 1131        raise ValueError(f"'{config}' is not a valid choice of configuration.")
1132
1133    if model_type is None:  # If user does not specify the model, we use the default model corresponding to the config.
1134        model_type = CONFIGURATIONS[config]["model_type"]
1135
1136    train_sam_for_configuration(
1137        name=checkpoint_name,
1138        configuration=config,
1139        model_type=model_type,
1140        train_loader=train_loader,
1141        val_loader=val_loader,
1142        n_epochs=epochs,
1143        checkpoint_path=checkpoint_path,
1144        with_segmentation_decoder=with_segmentation_decoder,
1145        freeze=None,  # TODO: Allow for PEFT.
1146        device=device,
1147        save_root=save_root,
1148        peft_kwargs=None,  # TODO: Allow for PEFT.
1149        train_instance_segmentation_only=train_instance_segmentation_only,
1150    )
1151
1152    # 4. Export the model, if desired by the user
1153    if train_instance_segmentation_only and output_path:
1154        trained_path = os.path.join("" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt")
1155        export_instance_segmentation_model(trained_path, output_path, model_type, checkpoint_path)
1156        final_path = output_path
1157    else:
1158        final_path = _export_helper(
1159            save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader,
1160        )
1161
1162    print(f"Training has finished. The trained model is saved at {final_path}.")
FilePath = typing.Union[str, os.PathLike]
def train_sam( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, n_objects_per_batch: Optional[int] = 25, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, n_sub_iteration: int = 8, save_root: Union[os.PathLike, str, NoneType] = None, mask_prob: float = 0.5, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt5.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, box_distortion_factor: Optional[float] = 0.025, overwrite_training: bool = True, **model_kwargs) -> None:
197def train_sam(
198    name: str,
199    model_type: str,
200    train_loader: DataLoader,
201    val_loader: DataLoader,
202    n_epochs: int = 100,
203    early_stopping: Optional[int] = 10,
204    n_objects_per_batch: Optional[int] = 25,
205    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
206    with_segmentation_decoder: bool = True,
207    freeze: Optional[List[str]] = None,
208    device: Optional[Union[str, torch.device]] = None,
209    lr: float = 1e-5,
210    n_sub_iteration: int = 8,
211    save_root: Optional[Union[str, os.PathLike]] = None,
212    mask_prob: float = 0.5,
213    n_iterations: Optional[int] = None,
214    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
215    scheduler_kwargs: Optional[Dict[str, Any]] = None,
216    save_every_kth_epoch: Optional[int] = None,
217    pbar_signals: Optional[QObject] = None,
218    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
219    peft_kwargs: Optional[Dict] = None,
220    ignore_warnings: bool = True,
221    verify_n_labels_in_loader: Optional[int] = 50,
222    box_distortion_factor: Optional[float] = 0.025,
223    overwrite_training: bool = True,
224    **model_kwargs,
225) -> None:
226    """Run training for a SAM model.
227
228    Args:
229        name: The name of the model to be trained. The checkpoint and logs will have this name.
230        model_type: The type of the SAM model.
231        train_loader: The dataloader for training.
232        val_loader: The dataloader for validation.
233        n_epochs: The number of epochs to train for.
234        early_stopping: Enable early stopping after this number of epochs without improvement.
235            By default, the value is set to '10' epochs.
236        n_objects_per_batch: The number of objects per batch used to compute
237            the loss for interactive segmentation. If None, all objects will be used;
238            if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
239        checkpoint_path: Path to checkpoint for initializing the SAM model.
240        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
241            By default, trains with the additional instance segmentation decoder.
242        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
243            By default, nothing is frozen and the full model is updated.
244        device: The device to use for training. By default, automatically chooses the best available device to train.
245        lr: The learning rate. By default, set to '1e-5'.
246        n_sub_iteration: The number of iterative prompts per training iteration.
247            By default, the number of iterations is set to '8'.
248        save_root: Optional root directory for saving the checkpoints and logs.
249            If not given the current working directory is used.
250        mask_prob: The probability for using a mask as input in a given training sub-iteration.
251            By default, set to '0.5'.
252        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
253        scheduler_class: The learning rate scheduler to update the learning rate.
254            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
255        scheduler_kwargs: The learning rate scheduler parameters.
256            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
257        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
258        pbar_signals: Controls for napari progress bar.
259        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
260        peft_kwargs: Keyword arguments for the PEFT wrapper class.
261        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
262        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
263            By default, 50 batches of labels are verified from the dataloaders.
264        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
265            By default, the distortion factor is set to '0.025'.
266        overwrite_training: Whether to overwrite the trained model stored at the same location.
267            By default, overwrites the trained model at each run.
268            If set to 'False', it will avoid retraining the model if the previous run was completed.
269        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
270    """
271    with _filter_warnings(ignore_warnings):
272
273        t_start = time.time()
274
275        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
276        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
277
278        device = get_device(device)
279        # Get the trainable segment anything model.
280        model, state = get_trainable_sam_model(
281            model_type=model_type,
282            device=device,
283            freeze=freeze,
284            checkpoint_path=checkpoint_path,
285            return_state=True,
286            peft_kwargs=peft_kwargs,
287            **model_kwargs
288        )
289
290        # This class creates all the training data for a batch (inputs, prompts and labels).
291        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)
292
293        # Create the UNETR decoder (if train with it) and the optimizer.
294        if with_segmentation_decoder:
295
296            # Get the UNETR.
297            unetr = get_unetr(
298                image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
299            )
300
301            # Get the parameters for SAM and the decoder from UNETR.
302            joint_model_params = [params for params in model.parameters()]  # sam parameters
303            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
304                if not param_name.startswith("encoder"):
305                    joint_model_params.append(params)
306
307            model_params = joint_model_params
308        else:
309            model_params = model.parameters()
310
311        optimizer, scheduler = _get_optimizer_and_scheduler(
312            model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs
313        )
314
315        # The trainer which performs training and validation.
316        if with_segmentation_decoder:
317            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
318            trainer = joint_trainers.JointSamTrainer(
319                name=name,
320                save_root=save_root,
321                train_loader=train_loader,
322                val_loader=val_loader,
323                model=model,
324                optimizer=optimizer,
325                device=device,
326                lr_scheduler=scheduler,
327                logger=joint_trainers.JointSamLogger,
328                log_image_interval=100,
329                mixed_precision=True,
330                convert_inputs=convert_inputs,
331                n_objects_per_batch=n_objects_per_batch,
332                n_sub_iteration=n_sub_iteration,
333                compile_model=False,
334                unetr=unetr,
335                instance_loss=instance_seg_loss,
336                instance_metric=instance_seg_loss,
337                early_stopping=early_stopping,
338                mask_prob=mask_prob,
339            )
340        else:
341            trainer = trainers.SamTrainer(
342                name=name,
343                train_loader=train_loader,
344                val_loader=val_loader,
345                model=model,
346                optimizer=optimizer,
347                device=device,
348                lr_scheduler=scheduler,
349                logger=trainers.SamLogger,
350                log_image_interval=100,
351                mixed_precision=True,
352                convert_inputs=convert_inputs,
353                n_objects_per_batch=n_objects_per_batch,
354                n_sub_iteration=n_sub_iteration,
355                compile_model=False,
356                early_stopping=early_stopping,
357                mask_prob=mask_prob,
358                save_root=save_root,
359            )
360
361        trainer_fit_params = _get_trainer_fit_params(
362            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
363        )
364        trainer.fit(**trainer_fit_params)
365
366        t_run = time.time() - t_start
367        hours = int(t_run // 3600)
368        minutes = int((t_run % 3600) // 60)
369        seconds = int(round(t_run % 60, 0))
370        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

Run training for a SAM model.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
  • n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default, nothing is frozen and the full model is updated.
  • device: The device to use for training. By default, automatically chooses the best available device to train.
  • lr: The learning rate. By default, set to '1e-5'.
  • n_sub_iteration: The number of iterative prompts per training iteration. By default, the number of iterations is set to '8'.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • mask_prob: The probability for using a mask as input in a given training sub-iteration. By default, set to '0.5'.
  • n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
  • verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
  • box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. By default, the distortion factor is set to '0.025'.
  • overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
  • model_kwargs: Additional keyword arguments for the micro_sam.util.get_sam_model.
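
For orientation, a minimal usage sketch follows. It builds the loaders with default_sam_loader (documented further below); the folder paths, checkpoint name and epoch count are placeholders and need to be adapted to your data:

    from micro_sam.training.training import default_sam_loader, train_sam

    # Hypothetical folders with 2D tif images and matching instance label images.
    train_loader = default_sam_loader(
        raw_paths="data/train/images", raw_key="*.tif",
        label_paths="data/train/labels", label_key="*.tif",
        patch_shape=(512, 512), with_segmentation_decoder=True,
        batch_size=1, shuffle=True,
    )
    val_loader = default_sam_loader(
        raw_paths="data/val/images", raw_key="*.tif",
        label_paths="data/val/labels", label_key="*.tif",
        patch_shape=(512, 512), with_segmentation_decoder=True,
        is_train=False, batch_size=1, shuffle=True,
    )

    # Fine-tune the default SAM ViT-B model; 5 epochs is only for illustration.
    train_sam(
        name="sam_finetuned", model_type="vit_b",
        train_loader=train_loader, val_loader=val_loader,
        n_epochs=5, with_segmentation_decoder=True,
    )
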
def export_instance_segmentation_model( trained_model_path: Union[str, os.PathLike], output_path: Union[str, os.PathLike], model_type: str, initial_checkpoint_path: Union[os.PathLike, str, NoneType] = None) -> None:
373def export_instance_segmentation_model(
374    trained_model_path: Union[str, os.PathLike],
375    output_path: Union[str, os.PathLike],
376    model_type: str,
377    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
378) -> None:
379    """Export a model trained for instance segmentation with `train_instance_segmentation`.
380
381    The exported model will be compatible with the micro_sam functions, CLI and napari plugin.
382    It should only be used for automatic segmentation and may not work well for interactive segmentation.
383
384    Args:
385        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
386        output_path: The path where the exported model will be saved.
387        model_type: The model type.
388        initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
389    """
390    trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"]
391
392    # Get the state of the encoder and instance segmentation decoder from the trained checkpoint.
393    encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")])
394    decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")])
395
396    # Load the original state of the model that was used as the basis of instance segmentation training.
397    _, model_state = get_sam_model(
398        model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu",
399    )
400    # Remove the sam prefix if it's in the model state.
401    prefix = "sam."
402    model_state = OrderedDict(
403        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
404    )
405
406    # Replace the image encoder state: 'image_encoder.*' keys are filled from the 'encoder.*' keys of the trained UNETR.
407    model_state = OrderedDict(
408        [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v)
409         for k, v in model_state.items()]
410    )
411
412    save_state = {"model_state": model_state, "decoder_state": decoder_state}
413    torch.save(save_state, output_path)

Export a model trained for instance segmentation with train_instance_segmentation.

The exported model will be compatible with the micro_sam functions, CLI and napari plugin. It should only be used for automatic segmentation and may not work well for interactive segmentation.

Arguments:
  • trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
  • output_path: The path where the exported model will be saved.
  • model_type: The model type.
  • initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
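
A short sketch of the export step, assuming a checkpoint written by train_instance_segmentation under the default 'checkpoints/<name>' layout (all paths and the name are placeholders):

    from micro_sam.training.training import export_instance_segmentation_model

    # Checkpoint from a previous train_instance_segmentation run with name="ais_only_model".
    export_instance_segmentation_model(
        trained_model_path="checkpoints/ais_only_model/best.pt",
        output_path="ais_only_model_exported.pt",
        model_type="vit_b",
    )
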
def train_instance_segmentation( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, loss: torch.nn.modules.module.Module = DiceBasedDistanceLoss( (foreground_loss): DiceLoss() (distance_loss): DiceLoss() ), metric: Optional[torch.nn.modules.module.Module] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, save_root: Union[os.PathLike, str, NoneType] = None, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt5.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, overwrite_training: bool = True, **model_kwargs) -> None:
416def train_instance_segmentation(
417    name: str,
418    model_type: str,
419    train_loader: DataLoader,
420    val_loader: DataLoader,
421    n_epochs: int = 100,
422    early_stopping: Optional[int] = 10,
423    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
424    metric: Optional[torch.nn.Module] = None,
425    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
426    freeze: Optional[List[str]] = None,
427    device: Optional[Union[str, torch.device]] = None,
428    lr: float = 1e-5,
429    save_root: Optional[Union[str, os.PathLike]] = None,
430    n_iterations: Optional[int] = None,
431    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
432    scheduler_kwargs: Optional[Dict[str, Any]] = None,
433    save_every_kth_epoch: Optional[int] = None,
434    pbar_signals: Optional[QObject] = None,
435    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
436    peft_kwargs: Optional[Dict] = None,
437    ignore_warnings: bool = True,
438    overwrite_training: bool = True,
439    **model_kwargs,
440) -> None:
441    """Train a UNETR for instance segmentation using the SAM encoder as backbone.
442
443    This setting corresponds to training a SAM model with an instance segmentation decoder,
444    without training the model parts for interactive segmentation,
445    i.e. without training the prompt encoder and mask decoder.
446
447    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
448    will not be compatible with the micro_sam functionality.
449    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
450    in a format that is compatible with micro_sam functionality.
451    Note that the exported model should only be used for automatic segmentation via AIS.
452
453    Args:
454        name: The name of the model to be trained. The checkpoint and logs will have this name.
455        model_type: The type of the SAM model.
456        train_loader: The dataloader for training.
457        val_loader: The dataloader for validation.
458        n_epochs: The number of epochs to train for.
459        early_stopping: Enable early stopping after this number of epochs without improvement.
460            By default, the value is set to '10' epochs.
461        loss: The loss function to train the instance segmentation model.
462            By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
463        metric: The metric for the instance segmentation training.
464            By default, the loss function is used as the metric.
465        checkpoint_path: Path to checkpoint for initializing the SAM model.
466        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
467            By default, nothing is frozen and the full model is updated.
468        device: The device to use for training. By default, automatically chooses the best available device to train.
469        lr: The learning rate. By default, set to '1e-5'.
470        save_root: Optional root directory for saving the checkpoints and logs.
471            If not given the current working directory is used.
472        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
473        scheduler_class: The learning rate scheduler to update the learning rate.
474            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
475        scheduler_kwargs: The learning rate scheduler parameters.
476            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
477        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
478        pbar_signals: Controls for napari progress bar.
479        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
480        peft_kwargs: Keyword arguments for the PEFT wrapper class.
481        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
482        overwrite_training: Whether to overwrite the trained model stored at the same location.
483            By default, overwrites the trained model at each run.
484            If set to 'False', it will avoid retraining the model if the previous run was completed.
485        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
486    """
487
488    with _filter_warnings(ignore_warnings):
489        t_start = time.time()
490
491        sam_model, state = get_trainable_sam_model(
492            model_type=model_type,
493            device=device,
494            checkpoint_path=checkpoint_path,
495            return_state=True,
496            peft_kwargs=peft_kwargs,
497            freeze=freeze,
498            **model_kwargs
499        )
500        device = get_device(device)
501        model = get_unetr(
502            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
503        )
504
505        optimizer, scheduler = _get_optimizer_and_scheduler(
506            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
507        )
508        trainer = torch_em.trainer.DefaultTrainer(
509            name=name,
510            model=model,
511            train_loader=train_loader,
512            val_loader=val_loader,
513            device=device,
514            mixed_precision=True,
515            log_image_interval=50,
516            compile_model=False,
517            save_root=save_root,
518            loss=loss,
519            metric=loss if metric is None else metric,
520            optimizer=optimizer,
521            lr_scheduler=scheduler,
522            early_stopping=early_stopping,
523        )
524
525        trainer_fit_params = _get_trainer_fit_params(
526            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
527        )
528        trainer.fit(**trainer_fit_params)
529
530        t_run = time.time() - t_start
531        hours = int(t_run // 3600)
532        minutes = int(t_run // 60)
533        seconds = int(round(t_run % 60, 0))
534        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

Train a UNETR for instance segmentation using the SAM encoder as backbone.

This setting corresponds to training a SAM model with an instance segmentation decoder, without training the model parts for interactive segmentation, i.e. without training the prompt encoder and mask decoder.

The checkpoint of the trained model, which will be saved in 'checkpoints/<name>', will not be compatible with the micro_sam functionality. You can call the function export_instance_segmentation_model with the path to the checkpoint to export it in a format that is compatible with micro_sam functionality. Note that the exported model should only be used for automatic segmentation via AIS.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
  • loss: The loss function to train the instance segmentation model. By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
  • metric: The metric for the instance segmentation training. By default, the loss function is used as the metric.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. By default, nothing is frozen and the full model is updated.
  • device: The device to use for training. By default, automatically chooses the best available device to train.
  • lr: The learning rate. By default, set to '1e-5'.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
  • overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
  • model_kwargs: Additional keyword arguments for the micro_sam.util.get_sam_model.
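
A minimal sketch of AIS-only training. The loaders are built with default_sam_loader and train_instance_segmentation_only=True, which is required for this training mode; paths, the name and the epoch count are placeholders:

    from micro_sam.training.training import default_sam_loader, train_instance_segmentation

    loader_kwargs = dict(
        raw_key="*.tif", label_key="*.tif", patch_shape=(512, 512),
        with_segmentation_decoder=True, train_instance_segmentation_only=True,
        batch_size=1, shuffle=True,
    )
    train_loader = default_sam_loader(
        raw_paths="data/train/images", label_paths="data/train/labels", **loader_kwargs
    )
    val_loader = default_sam_loader(
        raw_paths="data/val/images", label_paths="data/val/labels", is_train=False, **loader_kwargs
    )

    train_instance_segmentation(
        name="ais_only_model", model_type="vit_b",
        train_loader=train_loader, val_loader=val_loader, n_epochs=10,
    )

The resulting checkpoint can afterwards be made compatible with micro_sam via export_instance_segmentation_model.
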
def default_sam_dataset( raw_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: Optional[bool] = None, train_instance_segmentation_only: bool = False, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, is_train: bool = True, min_size: int = 25, max_sampling_attempts: Optional[int] = None, rois: Union[slice, Tuple[slice, ...], NoneType] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
570def default_sam_dataset(
571    raw_paths: Union[List[FilePath], FilePath],
572    raw_key: Optional[str],
573    label_paths: Union[List[FilePath], FilePath],
574    label_key: Optional[str],
575    patch_shape: Tuple[int],
576    with_segmentation_decoder: bool,
577    with_channels: Optional[bool] = None,
578    train_instance_segmentation_only: bool = False,
579    sampler: Optional[Callable] = None,
580    raw_transform: Optional[Callable] = None,
581    n_samples: Optional[int] = None,
582    is_train: bool = True,
583    min_size: int = 25,
584    max_sampling_attempts: Optional[int] = None,
585    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
586    **kwargs,
587) -> Dataset:
588    """Create a PyTorch Dataset for training a SAM model.
589
590    Args:
591        raw_paths: The path(s) to the image data used for training.
592            Can either be multiple 2D images or volumetric data.
593        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
594            or a glob pattern for selecting multiple files.
595        label_paths: The path(s) to the label data used for training.
596            Can either be multiple 2D images or volumetric data.
597        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
598            or a glob pattern for selecting multiple files.
599        patch_shape: The shape for training patches.
600        with_segmentation_decoder: Whether to train with additional segmentation decoder.
601        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
602        train_instance_segmentation_only: Set this argument to True in order to
603            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
604        sampler: A sampler to reject batches according to a given criterion.
605        raw_transform: Transformation applied to the image data.
606            If not given the data will be cast to 8bit.
607        n_samples: The number of samples for this dataset.
608        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
609        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
610        max_sampling_attempts: Number of sampling attempts to make from a dataset.
611        rois: The region of interest(s) for the data.
612        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
613
614    Returns:
615        The segmentation dataset.
616    """
617
618    # Check if this dataset should be used for instance segmentation only training.
619    # If yes, we set return_instances to False, since the instance channel must not
620    # be passed for this training mode.
621    return_instances = True
622    if train_instance_segmentation_only:
623        if not with_segmentation_decoder:
624            raise ValueError(
625                "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
626            )
627        return_instances = False
628
629    # If a sampler is not passed, then we set a MinInstanceSampler, which requires at least one instance per sample.
630    # This is necessary, because training for interactive segmentation does not work on 'empty' images.
631    # However, if we train only the automatic instance segmentation decoder, then this sampler is not required
632    # and we do not set a default sampler.
633    if sampler is None and not train_instance_segmentation_only:
634        sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)
635
636    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
637    is_seg_dataset = kwargs.pop("is_seg_dataset", None)
638
639    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
640    # Get valid raw paths to make checks possible.
641    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
642        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
643    else:  # Otherwise, 'raw_key' is either None or an internal path in a container format supported by 'elf'; we take one filepath.
644        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
645
646    # Load one of the raw inputs to validate whether it is RGB or not.
647    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
648    if test_raw_inputs.ndim == 3:
649        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
650            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
651            # We need to provide a list of inputs to 'ImageCollectionDataset'.
652            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
653            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
654
655            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
656            with_channels = False if with_channels is None else with_channels
657
658        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is an RGB image and has 3 channels first.
659            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
660            with_channels = True if with_channels is None else with_channels
661
662    # If 'with_channels' is still not set, neither by the user nor by the checks above,
663    # fall back to 'False', i.e. the default behavior of 'default_segmentation_dataset'.
664    with_channels = False if with_channels is None else with_channels
665
666    # Set the data transformations.
667    if raw_transform is None:
668        raw_transform = require_8bit
669
670    # Prepare the label transform.
671    if with_segmentation_decoder:
672        default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
673            distances=True,
674            boundary_distances=True,
675            directed_distances=False,
676            foreground=True,
677            instances=return_instances,
678            min_size=min_size,
679        )
680    else:
681        default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
682
683    # Allow combining label transforms.
684    custom_label_transform = kwargs.pop("label_transform", None)
685    if custom_label_transform is None:
686        label_transform = default_label_transform
687    else:
688        label_transform = torch_em.transform.generic.Compose(custom_label_transform, default_label_transform)
689
690    # Check the patch shape to add a singleton if required.
691    patch_shape = _update_patch_shape(
692        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
693    )
694
695    # Set a minimum number of samples per epoch.
696    if n_samples is None:
697        loader = torch_em.default_segmentation_loader(
698            raw_paths=raw_paths,
699            raw_key=raw_key,
700            label_paths=label_paths,
701            label_key=label_key,
702            batch_size=1,
703            patch_shape=patch_shape,
704            with_channels=with_channels,
705            ndim=2,
706            is_seg_dataset=is_seg_dataset,
707            raw_transform=raw_transform,
708            rois=rois,
709            **kwargs
710        )
711        n_samples = max(len(loader), 100 if is_train else 5)
712
713    dataset = torch_em.default_segmentation_dataset(
714        raw_paths=raw_paths,
715        raw_key=raw_key,
716        label_paths=label_paths,
717        label_key=label_key,
718        patch_shape=patch_shape,
719        raw_transform=raw_transform,
720        label_transform=label_transform,
721        with_channels=with_channels,
722        ndim=2,
723        sampler=sampler,
724        n_samples=n_samples,
725        is_seg_dataset=is_seg_dataset,
726        rois=rois,
727        **kwargs,
728    )
729
730    if max_sampling_attempts is not None:
731        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
732            for ds in dataset.datasets:
733                ds.max_sampling_attempts = max_sampling_attempts
734        else:
735            dataset.max_sampling_attempts = max_sampling_attempts
736
737    return dataset

Create a PyTorch Dataset for training a SAM model.

Arguments:
  • raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
  • raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
  • label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • patch_shape: The shape for training patches.
  • with_segmentation_decoder: Whether to train with additional segmentation decoder.
  • with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
  • train_instance_segmentation_only: Set this argument to True in order to pass the dataset to train_instance_segmentation. By default, set to 'False'.
  • sampler: A sampler to reject batches according to a given criterion.
  • raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
  • n_samples: The number of samples for this dataset.
  • is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
  • min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
  • max_sampling_attempts: Number of sampling attempts to make from a dataset.
  • rois: The region of interest(s) for the data.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.
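
For illustration, the sketch below builds a training dataset from a folder of 2D tif images with matching instance label images and wraps it in a loader via torch_em.get_data_loader (paths are placeholders):

    import torch_em
    from micro_sam.training.training import default_sam_dataset

    train_ds = default_sam_dataset(
        raw_paths="data/train/images", raw_key="*.tif",
        label_paths="data/train/labels", label_key="*.tif",
        patch_shape=(512, 512), with_segmentation_decoder=True,
    )
    train_loader = torch_em.get_data_loader(train_ds, batch_size=1, shuffle=True)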

def default_sam_loader(**kwargs) -> torch.utils.data.dataloader.DataLoader:
740def default_sam_loader(**kwargs) -> DataLoader:
741    """Create a PyTorch DataLoader for training a SAM model.
742
743    Args:
744        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
745
746    Returns:
747        The DataLoader.
748    """
749    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)
750
751    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
752    # which the users can provide to get their desired segmentation dataset.
753    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
754    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}
755
756    ds = default_sam_dataset(**ds_kwargs)
757    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)

Create a PyTorch DataLoader for training a SAM model.

Arguments:
  • kwargs: Keyword arguments for micro_sam.training.default_sam_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.
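
The keyword arguments are split automatically: dataset arguments are forwarded to default_sam_dataset, the remaining ones (e.g. batch_size, shuffle, num_workers) to the PyTorch DataLoader. A brief sketch with placeholder paths:

    from micro_sam.training.training import default_sam_loader

    val_loader = default_sam_loader(
        raw_paths="data/val/images", raw_key="*.tif",            # dataset arguments
        label_paths="data/val/labels", label_key="*.tif",
        patch_shape=(512, 512), with_segmentation_decoder=True, is_train=False,
        batch_size=1, shuffle=True, num_workers=2,               # DataLoader arguments
    )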

CONFIGURATIONS = {
    'Minimal': {'model_type': 'vit_t', 'n_objects_per_batch': 4, 'n_sub_iteration': 4},
    'CPU': {'model_type': 'vit_b', 'n_objects_per_batch': 10},
    'gtx1080': {'model_type': 'vit_t', 'n_objects_per_batch': 5},
    'gtx3080': {'model_type': 'vit_b', 'n_objects_per_batch': 5,
                'peft_kwargs': {'attention_layers_to_update': [11], 'peft_module': <class 'micro_sam.models.peft_sam.ClassicalSurgery'>}},
    'rtx5000': {'model_type': 'vit_b', 'n_objects_per_batch': 10},
    'V100': {'model_type': 'vit_b'},
    'A100': {'model_type': 'vit_h'},
}

Best training configurations for given hardware resources.

def train_sam_for_configuration( name: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, configuration: Optional[str] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, train_instance_segmentation_only: bool = False, model_type: Optional[str] = None, **kwargs) -> None:
799def train_sam_for_configuration(
800    name: str,
801    train_loader: DataLoader,
802    val_loader: DataLoader,
803    configuration: Optional[str] = None,
804    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
805    with_segmentation_decoder: bool = True,
806    train_instance_segmentation_only: bool = False,
807    model_type: Optional[str] = None,
808    **kwargs,
809) -> None:
810    """Run training for a SAM model with the configuration for a given hardware resource.
811
812    Selects the best training settings for the given configuration.
813    The available configurations are listed in `CONFIGURATIONS`.
814
815    Args:
816        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
817        train_loader: The dataloader for training.
818        val_loader: The dataloader for validation.
819        configuration: The configuration (= name of hardware resource).
820            By default, it is automatically selected for the best VRAM combination.
821        checkpoint_path: Path to checkpoint for initializing the SAM model.
822        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
823            By default, trains with the additional instance segmentation decoder.
824        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
825            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
826        model_type: Over-ride the default model type.
827            This can be used to use one of the micro_sam models as starting point
828            instead of a default sam model.
829        kwargs: Additional keyword parameters that will be passed to `train_sam`.
830    """
831    if configuration is None:  # Automatically choose based on available VRAM combination.
832        configuration = _find_best_configuration()
833
834    if configuration in CONFIGURATIONS:
835        train_kwargs = CONFIGURATIONS[configuration]
836    else:
837        raise ValueError(f"Invalid configuration {configuration}, expected one of {list(CONFIGURATIONS.keys())}")
838
839    if model_type is None:
840        model_type = train_kwargs.pop("model_type")
841    else:
842        expected_model_type = train_kwargs.pop("model_type")
843        if model_type[:5] != expected_model_type:
844            warnings.warn("You have specified a different model type.")
845
846    train_kwargs.update(**kwargs)
847    if train_instance_segmentation_only:
848        train_instance_segmentation(
849            name=name,
850            train_loader=train_loader,
851            val_loader=val_loader,
852            checkpoint_path=checkpoint_path,
853            with_segmentation_decoder=with_segmentation_decoder,
854            model_type=model_type,
855            **train_kwargs
856        )
857    else:
858        train_sam(
859            name=name,
860            train_loader=train_loader,
861            val_loader=val_loader,
862            checkpoint_path=checkpoint_path,
863            with_segmentation_decoder=with_segmentation_decoder,
864            model_type=model_type,
865            **train_kwargs
866        )

Run training for a SAM model with the configuration for a given hardware resource.

Selects the best training settings for the given configuration. The available configurations are listed in CONFIGURATIONS.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs folder will have this name.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • configuration: The configuration (= name of hardware resource). By default, it is automatically selected for the best VRAM combination.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
  • train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation using the training implementation train_instance_segmentation. By default, train_sam is used.
  • model_type: Over-ride the default model type. This can be used to use one of the micro_sam models as starting point instead of a default sam model.
  • kwargs: Additional keyword parameters that will be passed to train_sam.
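
A minimal sketch; "gtx3080" stands in for one of the keys of CONFIGURATIONS, and paths and names are placeholders. Leaving configuration unset lets the function pick a setting based on the available VRAM:

    from micro_sam.training.training import default_sam_loader, train_sam_for_configuration

    loader_kwargs = dict(
        raw_key="*.tif", label_key="*.tif", patch_shape=(512, 512),
        with_segmentation_decoder=True, batch_size=1, shuffle=True,
    )
    train_loader = default_sam_loader(
        raw_paths="data/train/images", label_paths="data/train/labels", **loader_kwargs
    )
    val_loader = default_sam_loader(
        raw_paths="data/val/images", label_paths="data/val/labels", is_train=False, **loader_kwargs
    )

    train_sam_for_configuration(
        name="sam_finetuned_gtx3080",
        train_loader=train_loader, val_loader=val_loader,
        configuration="gtx3080", with_segmentation_decoder=True,
    )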