micro_sam.training.training

   1import os
   2import time
   3import warnings
   4from glob import glob
   5from tqdm import tqdm
   6from collections import OrderedDict
   7from contextlib import contextmanager, nullcontext
   8from typing import Any, Callable, Dict, List, Optional, Tuple, Union
   9
  10import imageio.v3 as imageio
  11import numpy as np
  12import torch
  13from torch.optim import Optimizer
  14from torch.utils.data import random_split
  15from torch.utils.data import DataLoader, Dataset
  16from torch.optim.lr_scheduler import _LRScheduler
  17
  18import torch_em
  19from torch_em.util import load_data
  20from torch_em.data.datasets.util import split_kwargs
  21
  22from elf.io import open_file
  23
  24try:
  25    from qtpy.QtCore import QObject
  26except Exception:
  27    QObject = Any
  28
  29from . import sam_trainer as trainers
  30from ..instance_segmentation import get_unetr
  31from ..models.peft_sam import ClassicalSurgery
  32from . import joint_sam_trainer as joint_trainers
  33from ..util import get_device, get_model_names, export_custom_sam_model, get_sam_model
  34from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform
  35
  36
  37FilePath = Union[str, os.PathLike]
  38
  39
  40def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
  41    x, _ = next(iter(loader))
  42
  43    # Raw data: check that we have 1 or 3 channels.
  44    n_channels = x.shape[1]
  45    if n_channels not in (1, 3):
  46        raise ValueError(
  47            "Invalid number of channels for the input data from the data loader. "
  48            f"Expect 1 or 3 channels, got {n_channels}."
  49        )
  50
  51    # Raw data: check that the values are in the range [0, 255].
  52    minval, maxval = x.min(), x.max()
  53    if minval < 0 or minval > 255:
  54        raise ValueError(
  55            "Invalid data range for the input data from the data loader. "
  56            f"The input has to be in range [0, 255], but got minimum value {minval}."
  57        )
  58    if maxval < 1 or maxval > 255:
  59        raise ValueError(
  60            "Invalid data range for the input data from the data loader. "
  61            f"The input has to be in range [0, 255], but got maximum value {maxval}."
  62        )
  63
  64    # Target data: the check depends on whether we train with or without decoder.
  65    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).
  66
  67    def _check_instance_channel(instance_channel):
  68        unique_vals = torch.unique(instance_channel)
  69        if (unique_vals < 0).any():
  70            raise ValueError(
  71                "The target channel with the instance segmentation must not have negative values."
  72            )
  73        if len(unique_vals) == 1:
  74            raise ValueError(
  75                "The target channel with the instance segmentation must have at least one instance."
  76            )
  77        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
  78            raise ValueError(
  79                "All values in the target channel with the instance segmentation must be integer."
  80            )
  81
  82    counter = 0
  83    name = "" if name is None else f"'{name}'"
  84    for x, y in tqdm(
  85        loader,
  86        desc=f"Verifying labels in {name} dataloader",
  87        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
  88    ):
  89        n_channels_y = y.shape[1]
  90        if with_segmentation_decoder:
  91            if n_channels_y != 4:
  92                raise ValueError(
  93                    "Invalid number of channels in the target data from the data loader. "
  94                    "Expect 4 channel for training with an instance segmentation decoder, "
  95                    f"but got {n_channels_y} channels."
  96                )
  97            # Check instance channel per sample in a batch
  98            for per_y_sample in y:
  99                _check_instance_channel(per_y_sample[0])
 100
 101            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
 102            if targets_min < 0 or targets_min > 1:
 103                raise ValueError(
 104                    "Invalid value range in the target data from the value loader. "
 105                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
 106                    f"to be in range [0.0, 1.0], but got min {targets_min}"
 107                )
 108            if targets_max < 0 or targets_max > 1:
 109                raise ValueError(
 110                    "Invalid value range in the target data from the value loader. "
 111                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
 112                    f"to be in range [0.0, 1.0], but got max {targets_max}"
 113                )
 114
 115        else:
 116            if n_channels_y != 1:
 117                raise ValueError(
 118                    "Invalid number of channels in the target data from the data loader. "
 119                    "Expect 1 channel for training without an instance segmentation decoder, "
 120                    f"but got {n_channels_y} channels."
 121                )
 122            # Check instance channel per sample in a batch
 123            for per_y_sample in y:
 124                _check_instance_channel(per_y_sample)
 125
 126        counter += 1
 127        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
 128            break
 129
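# Note (added clarification, not part of the upstream checks): with a batch size of B and
# patches of size H x W, the checks above expect the raw batch `x` to have shape
# (B, 1, H, W) or (B, 3, H, W) with values in the range [0, 255]. The target batch `y` is
# expected to have shape (B, 4, H, W) when training with the segmentation decoder (the
# instance channel followed by three normalized target channels) and (B, 1, H, W) otherwise.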
 130
 131# Make the progress bar callbacks compatible with a tqdm progress bar interface.
 132class _ProgressBarWrapper:
 133    def __init__(self, signals):
 134        self._signals = signals
 135        self._total = None
 136
 137    @property
 138    def total(self):
 139        return self._total
 140
 141    @total.setter
 142    def total(self, value):
 143        self._signals.pbar_total.emit(value)
 144        self._total = value
 145
 146    def update(self, steps):
 147        self._signals.pbar_update.emit(steps)
 148
 149    def set_description(self, desc, **kwargs):
 150        self._signals.pbar_description.emit(desc)
 151
 152
 153@contextmanager
 154def _filter_warnings(ignore_warnings):
 155    if ignore_warnings:
 156        with warnings.catch_warnings():
 157            warnings.simplefilter("ignore")
 158            yield
 159    else:
 160        with nullcontext():
 161            yield
 162
 163
 164def _count_parameters(model_parameters):
 165    n_params = sum(p.numel() for p in model_parameters if p.requires_grad)
 166    n_params_in_millions = n_params / 1e6
 167    print(f"The number of trainable parameters for the provided model is {n_params} (~{round(n_params_in_millions, 2)}M).")
 168
 169
 170def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training):
 171    if n_iterations is None:
 172        trainer_fit_params = {"epochs": n_epochs}
 173    else:
 174        trainer_fit_params = {"iterations": n_iterations}
 175
 176    if save_every_kth_epoch is not None:
 177        trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch
 178
 179    if pbar_signals is not None:
 180        progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
 181        trainer_fit_params["progress"] = progress_bar_wrapper
 182
 183    # Avoid overwriting a trained model, if desired by the user.
 184    trainer_fit_params["overwrite_training"] = overwrite_training
 185    return trainer_fit_params
 186
 187
 188def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs):
 189    optimizer = optimizer_class(model_params, lr=lr)
 190    if scheduler_kwargs is None:
 191        scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3}
 192    scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)
 193    return optimizer, scheduler
 194
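# Hypothetical usage sketch (not part of the module): `train_sam` and
# `train_instance_segmentation` forward `optimizer_class`, `scheduler_class` and
# `scheduler_kwargs` to this helper, so a different optimizer or schedule can be plugged in.
# The concrete classes and values below are illustrative assumptions only, e.g.:
#
#   train_sam(
#       ...,
#       optimizer_class=torch.optim.SGD,
#       scheduler_class=torch.optim.lr_scheduler.StepLR,
#       scheduler_kwargs={"step_size": 10, "gamma": 0.5},
#   )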
 195
 196def train_sam(
 197    name: str,
 198    model_type: str,
 199    train_loader: DataLoader,
 200    val_loader: DataLoader,
 201    n_epochs: int = 100,
 202    early_stopping: Optional[int] = 10,
 203    n_objects_per_batch: Optional[int] = 25,
 204    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 205    with_segmentation_decoder: bool = True,
 206    freeze: Optional[List[str]] = None,
 207    device: Optional[Union[str, torch.device]] = None,
 208    lr: float = 1e-5,
 209    n_sub_iteration: int = 8,
 210    save_root: Optional[Union[str, os.PathLike]] = None,
 211    mask_prob: float = 0.5,
 212    n_iterations: Optional[int] = None,
 213    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
 214    scheduler_kwargs: Optional[Dict[str, Any]] = None,
 215    save_every_kth_epoch: Optional[int] = None,
 216    pbar_signals: Optional[QObject] = None,
 217    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
 218    peft_kwargs: Optional[Dict] = None,
 219    ignore_warnings: bool = True,
 220    verify_n_labels_in_loader: Optional[int] = 50,
 221    box_distortion_factor: Optional[float] = 0.025,
 222    overwrite_training: bool = True,
 223    **model_kwargs,
 224) -> None:
 225    """Run training for a SAM model.
 226
 227    Args:
 228        name: The name of the model to be trained. The checkpoint and logs will have this name.
 229        model_type: The type of the SAM model.
 230        train_loader: The dataloader for training.
 231        val_loader: The dataloader for validation.
 232        n_epochs: The number of epochs to train for.
 233        early_stopping: Enable early stopping after this number of epochs without improvement.
 234            By default, the value is set to '10' epochs.
 235        n_objects_per_batch: The number of objects per batch used to compute
 236            the loss for interactive segmentation. If None, all objects will be used;
 237            otherwise the objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
 238        checkpoint_path: Path to checkpoint for initializing the SAM model.
 239        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
 240            By default, trains with the additional instance segmentation decoder.
 241        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
 242            By default, nothing is frozen and the full model is updated.
 243        device: The device to use for training. By default, automatically chooses the best available device to train.
 244        lr: The learning rate. By default, set to '1e-5'.
 245        n_sub_iteration: The number of iterative prompts per training iteration.
 246            By default, the number of iterations is set to '8'.
 247        save_root: Optional root directory for saving the checkpoints and logs.
 248            If not given the current working directory is used.
 249        mask_prob: The probability for using a mask as input in a given training sub-iteration.
 250            By default, set to '0.5'.
 251        n_iterations: The number of iterations to use for training. This will override `n_epochs` if given.
 252        scheduler_class: The learning rate scheduler to update the learning rate.
 253            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
 254        scheduler_kwargs: The learning rate scheduler parameters.
 255            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
 256        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
 257        pbar_signals: Controls for napari progress bar.
 258        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
 259        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 260        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
 261        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
 262            By default, 50 batches of labels are verified from the dataloaders.
 263        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
 264            By default, the distortion factor is set to '0.025'.
 265        overwrite_training: Whether to overwrite the trained model stored at the same location.
 266            By default, overwrites the trained model at each run.
 267            If set to 'False', it will avoid retraining the model if the previous run was completed.
 268        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
 269    """
 270    with _filter_warnings(ignore_warnings):
 271
 272        t_start = time.time()
 273        if verify_n_labels_in_loader is not None:
 274            _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
 275            _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
 276
 277        device = get_device(device)
 278        # Get the trainable segment anything model.
 279        model, state = get_trainable_sam_model(
 280            model_type=model_type,
 281            device=device,
 282            freeze=freeze,
 283            checkpoint_path=checkpoint_path,
 284            return_state=True,
 285            peft_kwargs=peft_kwargs,
 286            **model_kwargs
 287        )
 288
 289        # This class creates all the training data for a batch (inputs, prompts and labels).
 290        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)
 291
 292        # Create the UNETR decoder (if we train with it) and the optimizer.
 293        if with_segmentation_decoder:
 294
 295            # Get the UNETR.
 296            unetr = get_unetr(
 297                image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
 298            )
 299
 300            # Get the parameters for SAM and the decoder from UNETR.
 301            joint_model_params = list(model.parameters())  # The SAM parameters.
 302            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
 303                if not param_name.startswith("encoder"):
 304                    joint_model_params.append(params)
 305
 306            model_params = joint_model_params
 307        else:
 308            model_params = model.parameters()
 309
 310        optimizer, scheduler = _get_optimizer_and_scheduler(
 311            model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs
 312        )
 313
 314        # The trainer which performs training and validation.
 315        if with_segmentation_decoder:
 316            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
 317            trainer = joint_trainers.JointSamTrainer(
 318                name=name,
 319                save_root=save_root,
 320                train_loader=train_loader,
 321                val_loader=val_loader,
 322                model=model,
 323                optimizer=optimizer,
 324                device=device,
 325                lr_scheduler=scheduler,
 326                logger=joint_trainers.JointSamLogger,
 327                log_image_interval=100,
 328                mixed_precision=True,
 329                convert_inputs=convert_inputs,
 330                n_objects_per_batch=n_objects_per_batch,
 331                n_sub_iteration=n_sub_iteration,
 332                compile_model=False,
 333                unetr=unetr,
 334                instance_loss=instance_seg_loss,
 335                instance_metric=instance_seg_loss,
 336                early_stopping=early_stopping,
 337                mask_prob=mask_prob,
 338            )
 339        else:
 340            trainer = trainers.SamTrainer(
 341                name=name,
 342                train_loader=train_loader,
 343                val_loader=val_loader,
 344                model=model,
 345                optimizer=optimizer,
 346                device=device,
 347                lr_scheduler=scheduler,
 348                logger=trainers.SamLogger,
 349                log_image_interval=100,
 350                mixed_precision=True,
 351                convert_inputs=convert_inputs,
 352                n_objects_per_batch=n_objects_per_batch,
 353                n_sub_iteration=n_sub_iteration,
 354                compile_model=False,
 355                early_stopping=early_stopping,
 356                mask_prob=mask_prob,
 357                save_root=save_root,
 358            )
 359
 360        trainer_fit_params = _get_trainer_fit_params(
 361            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
 362        )
 363        trainer.fit(**trainer_fit_params)
 364
 365        t_run = time.time() - t_start
 366        hours = int(t_run // 3600)
 367        minutes = int((t_run % 3600) // 60)
 368        seconds = int(round(t_run % 60, 0))
 369        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
 370
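# --- Hypothetical usage sketch (not part of the module) ---
# A minimal example of finetuning SAM with `train_sam`, using loaders built via
# `default_sam_loader` (defined further below). The file locations, keys, patch shape
# and epoch count are illustrative assumptions only.
def _example_finetune_sam():
    loader_kwargs = dict(
        raw_key="*.tif", label_key="*.tif", patch_shape=(512, 512),
        with_segmentation_decoder=True, batch_size=1, shuffle=True,
    )
    train_loader = default_sam_loader(
        raw_paths="data/train/images", label_paths="data/train/labels", **loader_kwargs
    )
    val_loader = default_sam_loader(
        raw_paths="data/val/images", label_paths="data/val/labels", **loader_kwargs
    )
    # Finetune 'vit_b' together with the additional instance segmentation decoder.
    train_sam(
        name="sam_finetuned", model_type="vit_b",
        train_loader=train_loader, val_loader=val_loader,
        n_epochs=10, with_segmentation_decoder=True,
    )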
 371
 372def export_instance_segmentation_model(
 373    trained_model_path: Union[str, os.PathLike],
 374    output_path: Union[str, os.PathLike],
 375    model_type: str,
 376    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 377) -> None:
 378    """Export a model trained for instance segmentation with `train_instance_segmentation`.
 379
 380    The exported model will be compatible with the micro_sam functions, CLI and napari plugin.
 381    It should only be used for automatic segmentation and may not work well for interactive segmentation.
 382
 383    Args:
 384        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
 385        output_path: The path where the exported model will be saved.
 386        model_type: The model type.
 387        initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
 388    """
 389    trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"]
 390
 391    # Get the state of the encoder and instance segmentation decoder from the trained checkpoint.
 392    encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")])
 393    decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")])
 394
 395    # Load the original state of the model that was used as the basis of instance segmentation training.
 396    _, model_state = get_sam_model(
 397        model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu",
 398    )
 399    # Remove the sam prefix if it's in the model state.
 400    prefix = "sam."
 401    model_state = OrderedDict(
 402        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
 403    )
 404
 405    # Replace the image encoder state. SAM's 'image_encoder.*' keys correspond to the UNETR's 'encoder.*' keys, so 'k[6:]' strips the 'image_' prefix.
 406    model_state = OrderedDict(
 407        [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v)
 408         for k, v in model_state.items()]
 409    )
 410
 411    save_state = {"model_state": model_state, "decoder_state": decoder_state}
 412    torch.save(save_state, output_path)
 413
 414
 415def train_instance_segmentation(
 416    name: str,
 417    model_type: str,
 418    train_loader: DataLoader,
 419    val_loader: DataLoader,
 420    n_epochs: int = 100,
 421    early_stopping: Optional[int] = 10,
 422    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
 423    metric: Optional[torch.nn.Module] = None,
 424    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 425    freeze: Optional[List[str]] = None,
 426    device: Optional[Union[str, torch.device]] = None,
 427    lr: float = 1e-5,
 428    save_root: Optional[Union[str, os.PathLike]] = None,
 429    n_iterations: Optional[int] = None,
 430    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
 431    scheduler_kwargs: Optional[Dict[str, Any]] = None,
 432    save_every_kth_epoch: Optional[int] = None,
 433    pbar_signals: Optional[QObject] = None,
 434    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
 435    peft_kwargs: Optional[Dict] = None,
 436    ignore_warnings: bool = True,
 437    overwrite_training: bool = True,
 438    **model_kwargs,
 439) -> None:
 440    """Train a UNETR for instance segmentation using the SAM encoder as backbone.
 441
 442    This setting corresponds to training a SAM model with an instance segmentation decoder,
 443    without training the model parts for interactive segmentation,
 444    i.e. without training the prompt encoder and mask decoder.
 445
 446    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
 447    will not be compatible with the micro_sam functionality.
 448    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
 449    in a format that is compatible with micro_sam functionality.
 450    Note that the exported model should only be used for automatic segmentation via AIS.
 451
 452    Args:
 453        name: The name of the model to be trained. The checkpoint and logs will have this name.
 454        model_type: The type of the SAM model.
 455        train_loader: The dataloader for training.
 456        val_loader: The dataloader for validation.
 457        n_epochs: The number of epochs to train for.
 458        early_stopping: Enable early stopping after this number of epochs without improvement.
 459            By default, the value is set to '10' epochs.
 460        loss: The loss function to train the instance segmentation model.
 461            By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
 462        metric: The metric for the instance segmentation training.
 463            By default the loss function is used as the metric.
 464        checkpoint_path: Path to checkpoint for initializing the SAM model.
 465        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
 466            By default nothing is frozen and the full model is updated.
 467        device: The device to use for training. By default, automatically chooses the best available device to train.
 468        lr: The learning rate. By default, set to '1e-5'.
 469        save_root: Optional root directory for saving the checkpoints and logs.
 470            If not given the current working directory is used.
 471        n_iterations: The number of iterations to use for training. This will override `n_epochs` if given.
 472        scheduler_class: The learning rate scheduler to update the learning rate.
 473            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
 474        scheduler_kwargs: The learning rate scheduler parameters.
 475            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
 476        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
 477        pbar_signals: Controls for napari progress bar.
 478        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
 479        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 480        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
 481        overwrite_training: Whether to overwrite the trained model stored at the same location.
 482            By default, overwrites the trained model at each run.
 483            If set to 'False', it will avoid retraining the model if the previous run was completed.
 484        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
 485    """
 486
 487    with _filter_warnings(ignore_warnings):
 488        t_start = time.time()
 489
 490        sam_model, state = get_trainable_sam_model(
 491            model_type=model_type,
 492            device=device,
 493            checkpoint_path=checkpoint_path,
 494            return_state=True,
 495            peft_kwargs=peft_kwargs,
 496            freeze=freeze,
 497            **model_kwargs
 498        )
 499        device = get_device(device)
 500        model = get_unetr(
 501            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
 502        )
 503
 504        optimizer, scheduler = _get_optimizer_and_scheduler(
 505            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
 506        )
 507        trainer = torch_em.trainer.DefaultTrainer(
 508            name=name,
 509            model=model,
 510            train_loader=train_loader,
 511            val_loader=val_loader,
 512            device=device,
 513            mixed_precision=True,
 514            log_image_interval=50,
 515            compile_model=False,
 516            save_root=save_root,
 517            loss=loss,
 518            metric=loss if metric is None else metric,
 519            optimizer=optimizer,
 520            lr_scheduler=scheduler,
 521            early_stopping=early_stopping,
 522        )
 523
 524        trainer_fit_params = _get_trainer_fit_params(
 525            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
 526        )
 527        trainer.fit(**trainer_fit_params)
 528
 529        t_run = time.time() - t_start
 530        hours = int(t_run // 3600)
 531        minutes = int((t_run % 3600) // 60)
 532        seconds = int(round(t_run % 60, 0))
 533        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
 534
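# --- Hypothetical usage sketch (not part of the module) ---
# Train only the instance segmentation decoder (AIS) and export the result in a
# micro_sam-compatible format afterwards. The paths, keys and model type are
# illustrative assumptions only; `default_sam_loader` is defined further below.
def _example_train_instance_segmentation():
    loader_kwargs = dict(
        raw_key="*.tif", label_key="*.tif", patch_shape=(512, 512),
        with_segmentation_decoder=True, train_instance_segmentation_only=True, batch_size=1,
    )
    train_loader = default_sam_loader(
        raw_paths="data/train/images", label_paths="data/train/labels", **loader_kwargs
    )
    val_loader = default_sam_loader(
        raw_paths="data/val/images", label_paths="data/val/labels", **loader_kwargs
    )
    train_instance_segmentation(
        name="ais_decoder", model_type="vit_b",
        train_loader=train_loader, val_loader=val_loader, n_epochs=10,
    )
    # The raw checkpoint is not compatible with the rest of micro_sam, so export it.
    export_instance_segmentation_model(
        trained_model_path="checkpoints/ais_decoder/best.pt",
        output_path="ais_decoder_exported.pt",
        model_type="vit_b",
    )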
 535
 536def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
 537    if isinstance(raw_paths, (str, os.PathLike)):
 538        path = raw_paths
 539    else:
 540        path = raw_paths[0]
 541    assert isinstance(path, (str, os.PathLike, np.ndarray, torch.Tensor))
 542
 543    # Check the underlying data dimensionality.
 544    if raw_key is None and isinstance(path, (np.ndarray, torch.Tensor)):
 545        ndim = path.ndim
 546    elif raw_key is None:  # If no key is given and this is not a tensor/array then we assume it's an image file.
 547        ndim = imageio.imread(path).ndim
 548    else:  # Otherwise we try to open the file from key.
 549        try:  # First try to open it with elf.
 550            with open_file(path, "r") as f:
 551                ndim = f[raw_key].ndim
 552        except ValueError:  # This may fail for images in a folder with different sizes.
 553            # In that case we read one of the images.
 554            image_path = glob(os.path.join(path, raw_key))[0]
 555            ndim = imageio.imread(image_path).ndim
 556
 557    if not isinstance(patch_shape, tuple):
 558        patch_shape = tuple(patch_shape)
 559
 560    if ndim == 2:
 561        assert len(patch_shape) == 2
 562        return patch_shape
 563    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
 564        return (1,) + patch_shape
 565    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
 566        return (1,) + patch_shape
 567    else:
 568        return patch_shape
 569
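# Illustration (added clarification, not part of the upstream code): for 2D images a patch
# shape of (512, 512) is returned unchanged, while for volumetric data without channels it
# becomes (1, 512, 512), so that single 2D slices are sampled from the volume.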
 570
 571def _determine_dataset_and_channels(raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs):
 572    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
 573    is_seg_dataset = kwargs.pop("is_seg_dataset", None)
 574
 575    # Check if the input data is passed as numpy arrays or torch tensors. In this case we use
 576    # the image collection dataset heuristic (torch_em will figure out the details),
 577    # and we determine the number of channels from the first input.
 578    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)):
 579        is_seg_dataset = False
 580        if with_channels is None:
 581            with_channels = raw_paths[0].ndim == 3 and raw_paths[0].shape[-1] == 3
 582        return is_seg_dataset, with_channels
 583
 584    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
 585    # Get valid raw paths to make checks possible.
 586    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath of a single image.
 587        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
 588    else:  # Otherwise, 'raw_key' is either None or an internal path of a container format supported by 'elf'; we load one filepath.
 589        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
 590
 591    # Load one of the raw inputs to validate whether it is RGB or not.
 592    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
 593    if test_raw_inputs.ndim == 3:
 594        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
 595            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
 596            # We need to provide a list of inputs to 'ImageCollectionDataset'.
 597            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
 598            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
 599
 600            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
 601            with_channels = False if with_channels is None else with_channels
 602
 603        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is an RGB image and has 3 channels first.
 604            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
 605            with_channels = True if with_channels is None else with_channels
 606
 607    # If 'with_channels' is still undetermined, fall back to 'False', i.e. the default behavior
 608    # of 'default_segmentation_dataset'. A value passed by the user always takes priority.
 609    with_channels = False if with_channels is None else with_channels
 610
 611    return is_seg_dataset, with_channels
 612
 613
 614def default_sam_dataset(
 615    raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
 616    raw_key: Optional[str],
 617    label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
 618    label_key: Optional[str],
 619    patch_shape: Tuple[int],
 620    with_segmentation_decoder: bool,
 621    with_channels: Optional[bool] = None,
 622    train_instance_segmentation_only: bool = False,
 623    sampler: Optional[Callable] = None,
 624    raw_transform: Optional[Callable] = None,
 625    n_samples: Optional[int] = None,
 626    is_train: bool = True,
 627    min_size: int = 25,
 628    max_sampling_attempts: Optional[int] = None,
 629    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
 630    is_multi_tensor: bool = True,
 631    **kwargs,
 632) -> Dataset:
 633    """Create a PyTorch Dataset for training a SAM model.
 634
 635    Args:
 636        raw_paths: The path(s) to the image data used for training.
 637            Can either be multiple 2D images or volumetric data.
 638            The data can also be passed as a list of numpy arrays or torch tensors.
 639        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
 640            or a glob pattern for selecting multiple files.
 641            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
 642        label_paths: The path(s) to the label data used for training.
 643            Can either be multiple 2D images or volumetric data.
 644            The data can also be passed as a list of numpy arrays or torch tensors.
 645        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
 646            or a glob pattern for selecting multiple files.
 647            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
 648        patch_shape: The shape for training patches.
 649        with_segmentation_decoder: Whether to train with additional segmentation decoder.
 650        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
 651        train_instance_segmentation_only: Set this argument to True in order to
 652            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
 653        sampler: A sampler to reject batches according to a given criterion.
 654        raw_transform: Transformation applied to the image data.
 655            If not given the data will be cast to 8bit.
 656        n_samples: The number of samples for this dataset.
 657        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
 658        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
 659        max_sampling_attempts: Number of sampling attempts to make from a dataset.
 660        rois: The region of interest(s) for the data.
 661        is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
 662        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 663
 664    Returns:
 665        The segmentation dataset.
 666    """
 667
 668    # Check if this dataset should be used for instance segmentation only training.
 669    # If yes, we set return_instances to False, since the instance channel must not
 670    # be passed for this training mode.
 671    return_instances = True
 672    if train_instance_segmentation_only:
 673        if not with_segmentation_decoder:
 674            raise ValueError(
 675                "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
 676            )
 677        return_instances = False
 678
 679    # If a sampler is not passed, then we set a MinInstanceSampler, which rejects samples without (sufficiently large) instances.
 680    # This is necessary, because training for interactive segmentation does not work on 'empty' images.
 681    # However, if we train only the automatic instance segmentation decoder, then this sampler is not required
 682    # and we do not set a default sampler.
 683    if sampler is None and not train_instance_segmentation_only:
 684        sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)
 685
 686    is_seg_dataset, with_channels = _determine_dataset_and_channels(
 687        raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs
 688    )
 689
 690    # Set the data transformations.
 691    if raw_transform is None:
 692        raw_transform = require_8bit
 693
 694    # Prepare the label transform.
 695    if with_segmentation_decoder:
 696        default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
 697            distances=True,
 698            boundary_distances=True,
 699            directed_distances=False,
 700            foreground=True,
 701            instances=return_instances,
 702            min_size=min_size,
 703        )
 704    else:
 705        default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
 706
 707    # Allow combining label transforms.
 708    custom_label_transform = kwargs.pop("label_transform", None)
 709    if custom_label_transform is None:
 710        label_transform = default_label_transform
 711    else:
 712        label_transform = torch_em.transform.generic.Compose(
 713            custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor
 714        )
 715
 716    # Check the patch shape to add a singleton if required.
 717    patch_shape = _update_patch_shape(
 718        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
 719    )
 720
 721    # Set a minimum number of samples per epoch.
 722    if n_samples is None:
 723        loader = torch_em.default_segmentation_loader(
 724            raw_paths=raw_paths,
 725            raw_key=raw_key,
 726            label_paths=label_paths,
 727            label_key=label_key,
 728            batch_size=1,
 729            patch_shape=patch_shape,
 730            with_channels=with_channels,
 731            ndim=2,
 732            is_seg_dataset=is_seg_dataset,
 733            raw_transform=raw_transform,
 734            rois=rois,
 735            **kwargs
 736        )
 737        n_samples = max(len(loader), 100 if is_train else 5)
 738
 739    dataset = torch_em.default_segmentation_dataset(
 740        raw_paths=raw_paths,
 741        raw_key=raw_key,
 742        label_paths=label_paths,
 743        label_key=label_key,
 744        patch_shape=patch_shape,
 745        raw_transform=raw_transform,
 746        label_transform=label_transform,
 747        with_channels=with_channels,
 748        ndim=2,
 749        sampler=sampler,
 750        n_samples=n_samples,
 751        is_seg_dataset=is_seg_dataset,
 752        rois=rois,
 753        **kwargs,
 754    )
 755
 756    if max_sampling_attempts is not None:
 757        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
 758            for ds in dataset.datasets:
 759                ds.max_sampling_attempts = max_sampling_attempts
 760        else:
 761            dataset.max_sampling_attempts = max_sampling_attempts
 762
 763    return dataset
 764
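# --- Hypothetical usage sketch (not part of the module) ---
# `default_sam_dataset` also accepts in-memory data: lists of numpy arrays (or torch
# tensors) for images and labels, with the keys set to None. The array shapes and the
# dummy instance below are illustrative assumptions only.
def _example_in_memory_dataset():
    images = [np.random.randint(0, 256, (512, 512)).astype("uint8") for _ in range(4)]
    labels = []
    for _ in range(4):
        label = np.zeros((512, 512), dtype="uint32")
        label[100:200, 100:200] = 1  # One (dummy) instance per image so that sampling can succeed.
        labels.append(label)
    return default_sam_dataset(
        raw_paths=images, raw_key=None,
        label_paths=labels, label_key=None,
        patch_shape=(512, 512), with_segmentation_decoder=True,
    )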
 765
 766def default_sam_loader(**kwargs) -> DataLoader:
 767    """Create a PyTorch DataLoader for training a SAM model.
 768
 769    Args:
 770        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
 771
 772    Returns:
 773        The DataLoader.
 774    """
 775    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)
 776
 777    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
 778    # which the users can provide to get their desired segmentation dataset.
 779    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
 780    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}
 781
 782    ds = default_sam_dataset(**ds_kwargs)
 783    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
 784
 785
 786CONFIGURATIONS = {
 787    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
 788    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
 789    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
 790    "gtx3080": {
 791        "model_type": "vit_b", "n_objects_per_batch": 5,
 792        "peft_kwargs": {"attention_layers_to_update": [11], "peft_module": ClassicalSurgery}
 793    },
 794    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
 795    "V100": {"model_type": "vit_b"},
 796    "A100": {"model_type": "vit_h"},
 797}
 798"""Best training configurations for given hardware resources.
 799"""
 800
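# Illustration (added clarification, not part of the upstream code): each entry maps a
# hardware resource to keyword arguments for `train_sam`, e.g. CONFIGURATIONS["rtx5000"]
# yields {"model_type": "vit_b", "n_objects_per_batch": 10}. `train_sam_for_configuration`
# below looks up the requested entry (or auto-detects one via `_find_best_configuration`)
# and forwards these settings.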
 801
 802def _find_best_configuration():
 803    if torch.cuda.is_available():
 804
 805        # Check how much memory we have and select the best matching GPU
 806        # for the available VRAM size.
 807        _, vram = torch.cuda.mem_get_info()
 808        vram = vram / 1e9  # in GB
 809
 810        # Maybe we can get more configurations in the future.
 811        if vram > 80:  # More than 80 GB: use the A100 configurations.
 812            return "A100"
 813        elif vram > 30:  # More than 30 GB: use the V100 configurations.
 814            return "V100"
 815        elif vram > 14:  # More than 14 GB: use the RTX5000 configurations.
 816            return "rtx5000"
 817        elif vram > 8:  # More than 8 GB: use the GTX3080 configurations.
 818            return "gtx3080"
 819        else:  # Otherwise: not enough memory to train on the GPU, use CPU instead.
 820            return "CPU"
 821    else:
 822        return "CPU"
 823
 824
 825def train_sam_for_configuration(
 826    name: str,
 827    train_loader: DataLoader,
 828    val_loader: DataLoader,
 829    configuration: Optional[str] = None,
 830    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 831    with_segmentation_decoder: bool = True,
 832    train_instance_segmentation_only: bool = False,
 833    model_type: Optional[str] = None,
 834    **kwargs,
 835) -> None:
 836    """Run training for a SAM model with the configuration for a given hardware resource.
 837
 838    Selects the best training settings for the given configuration.
 839    The available configurations are listed in `CONFIGURATIONS`.
 840
 841    Args:
 842        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
 843        train_loader: The dataloader for training.
 844        val_loader: The dataloader for validation.
 845        configuration: The configuration (= name of hardware resource).
 846            By default, it is automatically selected based on the available VRAM.
 847        checkpoint_path: Path to checkpoint for initializing the SAM model.
 848        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
 849            By default, trains with the additional instance segmentation decoder.
 850        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
 851            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
 852        model_type: Over-ride the default model type.
 853            This can be used to use one of the micro_sam models as starting point
 854            instead of a default sam model.
 855        kwargs: Additional keyword parameters that will be passed to `train_sam`.
 856    """
 857    if configuration is None:  # Automatically choose based on available VRAM combination.
 858        configuration = _find_best_configuration()
 859
 860    if configuration in CONFIGURATIONS:
 861        train_kwargs = CONFIGURATIONS[configuration]
 862    else:
 863        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")
 864
 865    if model_type is None:
 866        model_type = train_kwargs.pop("model_type")
 867    else:
 868        expected_model_type = train_kwargs.pop("model_type")
 869        if model_type[:5] != expected_model_type:
 870            warnings.warn(f"You have specified the model type '{model_type}', which differs from the default '{expected_model_type}' for this configuration.")
 871
 872    train_kwargs.update(**kwargs)
 873    if train_instance_segmentation_only:
 874        instance_seg_kwargs, extra_kwargs = split_kwargs(train_instance_segmentation, **train_kwargs)
 875        model_kwargs, extra_kwargs = split_kwargs(get_sam_model, **extra_kwargs)
 876        instance_seg_kwargs.update(**model_kwargs)
 877
 878        train_instance_segmentation(
 879            name=name,
 880            train_loader=train_loader,
 881            val_loader=val_loader,
 882            checkpoint_path=checkpoint_path,
 883            model_type=model_type,
 884            **instance_seg_kwargs,
 885        )
 886    else:
 887        train_sam(
 888            name=name,
 889            train_loader=train_loader,
 890            val_loader=val_loader,
 891            checkpoint_path=checkpoint_path,
 892            with_segmentation_decoder=with_segmentation_decoder,
 893            model_type=model_type,
 894            **train_kwargs
 895        )
 896
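# --- Hypothetical usage sketch (not part of the module) ---
# Let the hardware-based defaults pick the training settings; the loader variables are
# illustrative assumptions only (see `default_sam_loader` above for how to create them).
def _example_train_for_configuration(train_loader, val_loader):
    train_sam_for_configuration(
        name="sam_finetuned_auto",
        train_loader=train_loader,
        val_loader=val_loader,
        configuration=None,  # None -> automatically selected via `_find_best_configuration`.
        with_segmentation_decoder=True,
    )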
 897
 898def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):
 899
 900    # The model is stored in the current working directory unless 'save_root' was specified by the user.
 901    if save_root is None:
 902        save_root = os.getcwd()  # Map this to current working directory, if not specified by the user.
 903
 904    # Get the 'best' model checkpoint ready for export.
 905    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
 906    if not os.path.exists(best_checkpoint):
 907        raise FileNotFoundError(f"The trained model was not found at the expected location: '{best_checkpoint}'.")
 908
 909    # Export the model if an output path has been given.
 910    if output_path:
 911
 912        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
 913        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
 914            export_custom_sam_model(
 915                checkpoint_path=best_checkpoint,
 916                model_type=model_type[:5],
 917                save_path=output_path,
 918                with_segmentation_decoder=with_segmentation_decoder,
 919            )
 920
 921        # Otherwise we export it as bioimage.io model.
 922        else:
 923            from micro_sam.bioimageio import export_sam_model
 924
 925            # Load image and corresponding labels from the val loader.
 926            with torch.no_grad():
 927                image_data, label_data = next(iter(val_loader))
 928                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()
 929
 930                # Select the first channel of the label image if we have a channel axis, i.e. the channel that contains the instance labels.
 931                if label_data.ndim == 3:
 932                    label_data = label_data[0]  # Gets the channel with instances.
 933                assert image_data.shape == label_data.shape
 934                label_data = label_data.astype("uint32")
 935
 936                export_sam_model(
 937                    image=image_data,
 938                    label_image=label_data,
 939                    model_type=model_type[:5],
 940                    name=checkpoint_name,
 941                    output_path=output_path,
 942                    checkpoint_path=best_checkpoint,
 943                )
 944
 945        # The final path where the model has been stored.
 946        final_path = output_path
 947
 948    else:  # If no exports have been made, inform the user about the best checkpoint.
 949        final_path = best_checkpoint
 950
 951    return final_path
 952
 953
 954def _parse_segmentation_decoder(segmentation_decoder):
 955    if segmentation_decoder in ("None", "none"):
 956        with_segmentation_decoder, train_instance_segmentation_only = False, False
 957    elif segmentation_decoder == "instances":
 958        with_segmentation_decoder, train_instance_segmentation_only = True, False
 959    elif segmentation_decoder == "instances_only":
 960        with_segmentation_decoder, train_instance_segmentation_only = True, True
 961    else:
 962        raise ValueError(
 963            "The 'segmentation_decoder' argument currently supports the values:\n"
 964            f"'instances', 'instances_only', or 'None'. You have passed {segmentation_decoder}."
 965        )
 966    return with_segmentation_decoder, train_instance_segmentation_only
 967
 968
 969def main():
 970    """@private"""
 971    import argparse
 972
 973    available_models = list(get_model_names())
 974    available_models = ", ".join(available_models)
 975
 976    available_configurations = list(CONFIGURATIONS.keys())
 977    available_configurations = ", ".join(available_configurations)
 978
 979    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")
 980
 981    # Images and labels for training.
 982    parser.add_argument(
 983        "--images", required=True, type=str, nargs="*",
 984        help="Filepath to images or the directory where the image data is stored."
 985    )
 986    parser.add_argument(
 987        "--labels", required=True, type=str, nargs="*",
 988        help="Filepath to ground-truth labels or the directory where the label data is stored."
 989    )
 990    parser.add_argument(
 991        "--image_key", type=str, default=None,
 992        help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file. "
 993    )
 994    parser.add_argument(
 995        "--label_key", type=str, default=None,
 996        help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file. "
 997    )
 998
 999    # Images and labels for validation.
1000    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
1001    # Users can choose to have their explicit validation set via this feature as well.
1002    parser.add_argument(
1003        "--val_images", type=str, nargs="*",
1004        help="Filepath to images for validation or the directory where the image data is stored."
1005    )
1006    parser.add_argument(
1007        "--val_labels", type=str, nargs="*",
1008        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
1009    )
1010    parser.add_argument(
1011        "--val_image_key", type=str, default=None,
1012        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
1013    )
1014    parser.add_argument(
1015        "--val_label_key", type=str, default=None,
1016        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
1017    )
1018
1019    # Other necessary stuff for training.
1020    parser.add_argument(
1021        "--configuration", type=str, default=_find_best_configuration(),
1022        help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations}."
1023    )
1024
1025    def none_or_str(value):
1026        if value.lower() == 'none':
1027            return None
1028        return value
1029
1030    # This could be extended to train for semantic segmentation or other options.
1031    parser.add_argument(
1032        "--segmentation_decoder", type=none_or_str, default="instances",
1033        help="Whether to finetune Segment Anything Model with an additional segmentation decoder. "
1034        "The following options are possible:\n"
1035        "- 'instances' to train with an additional decoder for automatic instance segmentation. "
1036        "  This option enables using the automatic instance segmentation (AIS) mode.\n"
1037        "- 'instances_only' to train only the instance segmentation decoder. "
1038        "  In this case the parts of SAM that are used for interactive segmentation will not be trained.\n"
1039        "- 'None' to train without an additional segmentation decoder."
1040        "  This options trains only the parts of the original SAM.\n"
1041        "By default the option 'instances' is used."
1042    )
1043
1044    # Optional advanced settings a user can opt to change the values for.
1045    parser.add_argument(
1046        "-d", "--device", type=str, default=None,
1047        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
1048        "By default the most performant available device will be selected."
1049    )
1050    parser.add_argument(
1051        "--patch_shape", type=int, nargs="*", default=(512, 512),
1052        help="The choice of patch shape for training Segment Anything Model. "
1053        "By default, a patch size of 512x512 is used."
1054    )
1055    parser.add_argument(
1056        "-m", "--model_type", type=str, default=None,
1057        help=f"The Segment Anything Model that will be used for finetuning, one of {available_models}."
1058    )
1059    parser.add_argument(
1060        "--checkpoint_path", type=str, default=None,
1061        help="Checkpoint from which the SAM model will be loaded for finetuning."
1062    )
1063    parser.add_argument(
1064        "-s", "--save_root", type=str, default=None,
1065        help="The directory where the trained models and corresponding logs will be stored. "
1066        "By default, there are stored in your current working directory."
1067    )
1068    parser.add_argument(
1069        "--trained_model_name", type=str, default="sam_model",
1070        help="The custom name of trained model sub-folder. Allows users to have several trained models "
1071        "under the same 'save_root'."
1072    )
1073    parser.add_argument(
1074        "--output_path", type=str, default=None,
1075        help="The directory (eg. '/path/to/folder') or filepath (eg. '/path/to/model.pt') to export the trained model."
1076    )
1077    parser.add_argument(
1078        "--n_epochs", type=int, default=100,
1079        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
1080    )
1081    parser.add_argument(
1082        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
1083    )
1084    parser.add_argument(
1085        "--batch_size", type=int, default=1,
1086        help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1."
1087    )
1088    parser.add_argument(
1089        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
1090        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images "
1091        "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
1092    )
1093
1094    args = parser.parse_args()
1095
1096    # 1. Get all necessary stuff for training.
1097    checkpoint_name = args.trained_model_name
1098    config = args.configuration
1099    model_type = args.model_type
1100    checkpoint_path = args.checkpoint_path
1101    batch_size = args.batch_size
1102    patch_shape = args.patch_shape
1103    epochs = args.n_epochs
1104    num_workers = args.num_workers
1105    device = args.device
1106    save_root = args.save_root
1107    output_path = args.output_path
1108    with_segmentation_decoder, train_instance_segmentation_only = _parse_segmentation_decoder(args.segmentation_decoder)
1109
1110    # Get image paths and corresponding keys.
1111    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
1112    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa
1113
1114    # 2. Prepare the dataloaders.
1115
1116    # If the user wants to preprocess the inputs, we allow the possibility to do so.
1117    _raw_transform = get_raw_transform(args.preprocess)
1118
1119    # Get the dataset with files for training.
1120    dataset = default_sam_dataset(
1121        raw_paths=train_images,
1122        raw_key=train_image_key,
1123        label_paths=train_gt,
1124        label_key=train_gt_key,
1125        patch_shape=patch_shape,
1126        with_segmentation_decoder=with_segmentation_decoder,
1127        raw_transform=_raw_transform,
1128        train_instance_segmentation_only=train_instance_segmentation_only,
1129    )
1130
1131    # If no separate val images are provided, we create a val split from the training data.
1132    if val_images is None:
1133        assert val_gt is None and val_image_key is None and val_gt_key is None
1134        # Use 10% of the dataset - at least one image - for validation.
1135        n_val = max(1, int(0.1 * len(dataset)))
1136        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])
1137
1138    else:  # If val images are provided, we create a separate dataset for them.
1139        train_dataset = dataset
1140        val_dataset = default_sam_dataset(
1141            raw_paths=val_images,
1142            raw_key=val_image_key,
1143            label_paths=val_gt,
1144            label_key=val_gt_key,
1145            patch_shape=patch_shape,
1146            with_segmentation_decoder=with_segmentation_decoder,
1147            train_instance_segmentation_only=train_instance_segmentation_only,
1148            raw_transform=_raw_transform,
1149        )
1150
1151    # Get the dataloaders from the datasets.
1152    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
1153    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)
1154
1155    # 3. Train the Segment Anything Model.
1156
1157    # Get a valid model and other necessary parameters for training.
1158    if model_type is not None and model_type not in available_models:
1159        raise ValueError(f"'{model_type}' is not a valid choice of model.")
1160    if config is not None and config not in available_configurations:
1161        raise ValueError(f"'{config}' is not a valid choice of configuration.")
1162
1163    if model_type is None:  # If user does not specify the model, we use the default model corresponding to the config.
1164        model_type = CONFIGURATIONS[config]["model_type"]
1165
1166    train_sam_for_configuration(
1167        name=checkpoint_name,
1168        configuration=config,
1169        model_type=model_type,
1170        train_loader=train_loader,
1171        val_loader=val_loader,
1172        n_epochs=epochs,
1173        checkpoint_path=checkpoint_path,
1174        with_segmentation_decoder=with_segmentation_decoder,
1175        freeze=None,  # TODO: Allow for PEFT.
1176        device=device,
1177        save_root=save_root,
1178        peft_kwargs=None,  # TODO: Allow for PEFT.
1179        train_instance_segmentation_only=train_instance_segmentation_only,
1180    )
1181
1182    # 4. Export the model, if desired by the user
1183    if train_instance_segmentation_only and output_path:
1184        trained_path = os.path.join("" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt")
1185        export_instance_segmentation_model(trained_path, output_path, model_type, checkpoint_path)
1186        final_path = output_path
1187    else:
1188        final_path = _export_helper(
1189            save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader,
1190        )
1191
1192    print(f"Training has finished. The trained model is saved at {final_path}.")
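
For reference, the validation-split logic from step 2 above can be reproduced directly in Python when no separate validation data is available. A minimal sketch with hypothetical paths ('data/images', 'data/labels'), using the same helpers the CLI uses:

    from torch.utils.data import random_split
    import torch_em
    from micro_sam.training.training import default_sam_dataset

    # Hypothetical data layout: a folder of tif images and a folder of instance label masks.
    dataset = default_sam_dataset(
        raw_paths="data/images", raw_key="*.tif",
        label_paths="data/labels", label_key="*.tif",
        patch_shape=(512, 512), with_segmentation_decoder=True,
    )

    # Reserve 10% of the dataset (at least one sample) for validation, as the CLI does.
    n_val = max(1, int(0.1 * len(dataset)))
    train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])

    train_loader = torch_em.get_data_loader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=4)
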
FilePath = typing.Union[str, os.PathLike]
def train_sam( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, n_objects_per_batch: Optional[int] = 25, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, n_sub_iteration: int = 8, save_root: Union[os.PathLike, str, NoneType] = None, mask_prob: float = 0.5, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt5.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, box_distortion_factor: Optional[float] = 0.025, overwrite_training: bool = True, **model_kwargs) -> None:
197def train_sam(
198    name: str,
199    model_type: str,
200    train_loader: DataLoader,
201    val_loader: DataLoader,
202    n_epochs: int = 100,
203    early_stopping: Optional[int] = 10,
204    n_objects_per_batch: Optional[int] = 25,
205    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
206    with_segmentation_decoder: bool = True,
207    freeze: Optional[List[str]] = None,
208    device: Optional[Union[str, torch.device]] = None,
209    lr: float = 1e-5,
210    n_sub_iteration: int = 8,
211    save_root: Optional[Union[str, os.PathLike]] = None,
212    mask_prob: float = 0.5,
213    n_iterations: Optional[int] = None,
214    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
215    scheduler_kwargs: Optional[Dict[str, Any]] = None,
216    save_every_kth_epoch: Optional[int] = None,
217    pbar_signals: Optional[QObject] = None,
218    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
219    peft_kwargs: Optional[Dict] = None,
220    ignore_warnings: bool = True,
221    verify_n_labels_in_loader: Optional[int] = 50,
222    box_distortion_factor: Optional[float] = 0.025,
223    overwrite_training: bool = True,
224    **model_kwargs,
225) -> None:
226    """Run training for a SAM model.
227
228    Args:
229        name: The name of the model to be trained. The checkpoint and logs will have this name.
230        model_type: The type of the SAM model.
231        train_loader: The dataloader for training.
232        val_loader: The dataloader for validation.
233        n_epochs: The number of epochs to train for.
234        early_stopping: Enable early stopping after this number of epochs without improvement.
235            By default, the value is set to '10' epochs.
236        n_objects_per_batch: The number of objects per batch used to compute
237            the loss for interactive segmentation. If None, all objects will be used;
238            if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
239        checkpoint_path: Path to checkpoint for initializing the SAM model.
240        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
241            By default, trains with the additional instance segmentation decoder.
242        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
243            By default, nothing is frozen and the full model is updated.
244        device: The device to use for training. By default, automatically chooses the best available device to train.
245        lr: The learning rate. By default, set to '1e-5'.
246        n_sub_iteration: The number of iterative prompts per training iteration.
247            By default, the number of iterations is set to '8'.
248        save_root: Optional root directory for saving the checkpoints and logs.
249            If not given the current working directory is used.
250        mask_prob: The probability for using a mask as input in a given training sub-iteration.
251            By default, set to '0.5'.
252        n_iterations: The number of iterations to use for training. This will override `n_epochs` if given.
253        scheduler_class: The learning rate scheduler to update the learning rate.
254            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
255        scheduler_kwargs: The learning rate scheduler parameters.
256            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
257        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
258        pbar_signals: Controls for napari progress bar.
259        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
260        peft_kwargs: Keyword arguments for the PEFT wrapper class.
261        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
262        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
263            By default, 50 batches of labels are verified from the dataloaders.
264        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
265            By default, the distortion factor is set to '0.025'.
266        overwrite_training: Whether to overwrite the trained model stored at the same location.
267            By default, overwrites the trained model at each run.
268            If set to 'False', it will avoid retraining the model if the previous run was completed.
269        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
270    """
271    with _filter_warnings(ignore_warnings):
272
273        t_start = time.time()
274        if verify_n_labels_in_loader is not None:
275            _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
276            _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
277
278        device = get_device(device)
279        # Get the trainable segment anything model.
280        model, state = get_trainable_sam_model(
281            model_type=model_type,
282            device=device,
283            freeze=freeze,
284            checkpoint_path=checkpoint_path,
285            return_state=True,
286            peft_kwargs=peft_kwargs,
287            **model_kwargs
288        )
289
290        # This class creates all the training data for a batch (inputs, prompts and labels).
291        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)
292
293        # Create the UNETR decoder (if train with it) and the optimizer.
294        if with_segmentation_decoder:
295
296            # Get the UNETR.
297            unetr = get_unetr(
298                image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
299            )
300
301            # Get the parameters for SAM and the decoder from UNETR.
302            joint_model_params = [params for params in model.parameters()]  # sam parameters
303            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
304                if not param_name.startswith("encoder"):
305                    joint_model_params.append(params)
306
307            model_params = joint_model_params
308        else:
309            model_params = model.parameters()
310
311        optimizer, scheduler = _get_optimizer_and_scheduler(
312            model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs
313        )
314
315        # The trainer which performs training and validation.
316        if with_segmentation_decoder:
317            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
318            trainer = joint_trainers.JointSamTrainer(
319                name=name,
320                save_root=save_root,
321                train_loader=train_loader,
322                val_loader=val_loader,
323                model=model,
324                optimizer=optimizer,
325                device=device,
326                lr_scheduler=scheduler,
327                logger=joint_trainers.JointSamLogger,
328                log_image_interval=100,
329                mixed_precision=True,
330                convert_inputs=convert_inputs,
331                n_objects_per_batch=n_objects_per_batch,
332                n_sub_iteration=n_sub_iteration,
333                compile_model=False,
334                unetr=unetr,
335                instance_loss=instance_seg_loss,
336                instance_metric=instance_seg_loss,
337                early_stopping=early_stopping,
338                mask_prob=mask_prob,
339            )
340        else:
341            trainer = trainers.SamTrainer(
342                name=name,
343                train_loader=train_loader,
344                val_loader=val_loader,
345                model=model,
346                optimizer=optimizer,
347                device=device,
348                lr_scheduler=scheduler,
349                logger=trainers.SamLogger,
350                log_image_interval=100,
351                mixed_precision=True,
352                convert_inputs=convert_inputs,
353                n_objects_per_batch=n_objects_per_batch,
354                n_sub_iteration=n_sub_iteration,
355                compile_model=False,
356                early_stopping=early_stopping,
357                mask_prob=mask_prob,
358                save_root=save_root,
359            )
360
361        trainer_fit_params = _get_trainer_fit_params(
362            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
363        )
364        trainer.fit(**trainer_fit_params)
365
366        t_run = time.time() - t_start
367        hours = int(t_run // 3600)
368        minutes = int((t_run % 3600) // 60)
369        seconds = int(round(t_run % 60, 0))
370        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

Run training for a SAM model.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
  • n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default, nothing is frozen and the full model is updated.
  • device: The device to use for training. By default, automatically chooses the best available device to train.
  • lr: The learning rate. By default, set to '1e-5'.
  • n_sub_iteration: The number of iterative prompts per training iteration. By default, the number of iterations is set to '8'.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • mask_prob: The probability for using a mask as input in a given training sub-iteration. By default, set to '0.5'.
  • n_iterations: The number of iterations to use for training. This will override n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
  • verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
  • box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. By default, the distortion factor is set to '0.025'.
  • overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
  • model_kwargs: Additional keyword arguments for the micro_sam.util.get_sam_model.
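
A minimal usage sketch for `train_sam`; all names and paths below are hypothetical, and the loaders are built with `default_sam_loader` (documented further below):

    from micro_sam.training.training import default_sam_loader, train_sam

    # Hypothetical data layout: folders of tif images and instance label masks.
    loader_kwargs = dict(
        raw_paths="data/images", raw_key="*.tif",
        label_paths="data/labels", label_key="*.tif",
        patch_shape=(512, 512), with_segmentation_decoder=True,
        batch_size=1, shuffle=True,
    )
    train_loader = default_sam_loader(**loader_kwargs)
    val_loader = default_sam_loader(**loader_kwargs)  # in practice, point this at separate validation data

    train_sam(
        name="sam_finetuned",            # the checkpoint is stored under checkpoints/sam_finetuned
        model_type="vit_b",
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=50,
        with_segmentation_decoder=True,  # also train the UNETR decoder for automatic instance segmentation
    )
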
def export_instance_segmentation_model( trained_model_path: Union[str, os.PathLike], output_path: Union[str, os.PathLike], model_type: str, initial_checkpoint_path: Union[os.PathLike, str, NoneType] = None) -> None:
373def export_instance_segmentation_model(
374    trained_model_path: Union[str, os.PathLike],
375    output_path: Union[str, os.PathLike],
376    model_type: str,
377    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
378) -> None:
379    """Export a model trained for instance segmentation with `train_instance_segmentation`.
380
381    The exported model will be compatible with the micro_sam functions, CLI and napari plugin.
382    It should only be used for automatic segmentation and may not work well for interactive segmentation.
383
384    Args:
385        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
386        output_path: The path where the exported model will be saved.
387        model_type: The model type.
388        initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
389    """
390    trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"]
391
392    # Get the state of the encoder and instance segmentation decoder from the trained checkpoint.
393    encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")])
394    decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")])
395
396    # Load the original state of the model that was used as the basis of instance segmentation training.
397    _, model_state = get_sam_model(
398        model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu",
399    )
400    # Remove the sam prefix if it's in the model state.
401    prefix = "sam."
402    model_state = OrderedDict(
403        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
404    )
405
406    # Replace the image encoder state.
407    model_state = OrderedDict(
408        [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v)
409         for k, v in model_state.items()]
410    )
411
412    save_state = {"model_state": model_state, "decoder_state": decoder_state}
413    torch.save(save_state, output_path)

Export a model trained for instance segmentation with train_instance_segmentation.

The exported model will be compatible with the micro_sam functions, CLI and napari plugin. It should only be used for automatic segmentation and may not work well for interactive segmentation.

Arguments:
  • trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
  • output_path: The path where the exported model will be saved.
  • model_type: The model type.
  • initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
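
A brief usage sketch; the paths below are hypothetical, with 'best.pt' being the checkpoint written by `train_instance_segmentation` under 'checkpoints/<name>':

    from micro_sam.training.training import export_instance_segmentation_model

    export_instance_segmentation_model(
        trained_model_path="checkpoints/ais_only/best.pt",  # hypothetical checkpoint location
        output_path="ais_only_exported.pt",
        model_type="vit_b",
    )
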
def train_instance_segmentation( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, loss: torch.nn.modules.module.Module = DiceBasedDistanceLoss( (foreground_loss): DiceLoss() (distance_loss): DiceLoss() ), metric: Optional[torch.nn.modules.module.Module] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, save_root: Union[os.PathLike, str, NoneType] = None, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt5.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, overwrite_training: bool = True, **model_kwargs) -> None:
416def train_instance_segmentation(
417    name: str,
418    model_type: str,
419    train_loader: DataLoader,
420    val_loader: DataLoader,
421    n_epochs: int = 100,
422    early_stopping: Optional[int] = 10,
423    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
424    metric: Optional[torch.nn.Module] = None,
425    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
426    freeze: Optional[List[str]] = None,
427    device: Optional[Union[str, torch.device]] = None,
428    lr: float = 1e-5,
429    save_root: Optional[Union[str, os.PathLike]] = None,
430    n_iterations: Optional[int] = None,
431    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
432    scheduler_kwargs: Optional[Dict[str, Any]] = None,
433    save_every_kth_epoch: Optional[int] = None,
434    pbar_signals: Optional[QObject] = None,
435    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
436    peft_kwargs: Optional[Dict] = None,
437    ignore_warnings: bool = True,
438    overwrite_training: bool = True,
439    **model_kwargs,
440) -> None:
441    """Train a UNETR for instance segmentation using the SAM encoder as backbone.
442
443    This setting corresponds to training a SAM model with an instance segmentation decoder,
444    without training the model parts for interactive segmentation,
445    i.e. without training the prompt encoder and mask decoder.
446
447    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
448    will not be compatible with the micro_sam functionality.
449    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
450    in a format that is compatible with micro_sam functionality.
451    Note that the exported model should only be used for automatic segmentation via AIS.
452
453    Args:
454        name: The name of the model to be trained. The checkpoint and logs will have this name.
455        model_type: The type of the SAM model.
456        train_loader: The dataloader for training.
457        val_loader: The dataloader for validation.
458        n_epochs: The number of epochs to train for.
459        early_stopping: Enable early stopping after this number of epochs without improvement.
460            By default, the value is set to '10' epochs.
461        loss: The loss function to train the instance segmentation model.
462            By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
463        metric: The metric for the instance segmentation training.
464            By default the loss function is used as the metric.
465        checkpoint_path: Path to checkpoint for initializing the SAM model.
466        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
467            By default nothing is frozen and the full model is updated.
468        device: The device to use for training. By default, automatically chooses the best available device to train.
469        lr: The learning rate. By default, set to '1e-5'.
470        save_root: Optional root directory for saving the checkpoints and logs.
471            If not given the current working directory is used.
472        n_iterations: The number of iterations to use for training. This will override `n_epochs` if given.
473        scheduler_class: The learning rate scheduler to update the learning rate.
474            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
475        scheduler_kwargs: The learning rate scheduler parameters.
476            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
477        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
478        pbar_signals: Controls for napari progress bar.
479        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
480        peft_kwargs: Keyword arguments for the PEFT wrapper class.
481        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
482        overwrite_training: Whether to overwrite the trained model stored at the same location.
483            By default, overwrites the trained model at each run.
484            If set to 'False', it will avoid retraining the model if the previous run was completed.
485        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
486    """
487
488    with _filter_warnings(ignore_warnings):
489        t_start = time.time()
490
491        sam_model, state = get_trainable_sam_model(
492            model_type=model_type,
493            device=device,
494            checkpoint_path=checkpoint_path,
495            return_state=True,
496            peft_kwargs=peft_kwargs,
497            freeze=freeze,
498            **model_kwargs
499        )
500        device = get_device(device)
501        model = get_unetr(
502            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
503        )
504
505        optimizer, scheduler = _get_optimizer_and_scheduler(
506            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
507        )
508        trainer = torch_em.trainer.DefaultTrainer(
509            name=name,
510            model=model,
511            train_loader=train_loader,
512            val_loader=val_loader,
513            device=device,
514            mixed_precision=True,
515            log_image_interval=50,
516            compile_model=False,
517            save_root=save_root,
518            loss=loss,
519            metric=loss if metric is None else metric,
520            optimizer=optimizer,
521            lr_scheduler=scheduler,
522            early_stopping=early_stopping,
523        )
524
525        trainer_fit_params = _get_trainer_fit_params(
526            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
527        )
528        trainer.fit(**trainer_fit_params)
529
530        t_run = time.time() - t_start
531        hours = int(t_run // 3600)
532        minutes = int((t_run % 3600) // 60)
533        seconds = int(round(t_run % 60, 0))
534        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

Train a UNETR for instance segmentation using the SAM encoder as backbone.

This setting corresponds to training a SAM model with an instance segmentation decoder, without training the model parts for interactive segmentation, i.e. without training the prompt encoder and mask decoder.

The checkpoint of the trained model, which will be saved in 'checkpoints/<name>', will not be compatible with the micro_sam functionality. You can call the function export_instance_segmentation_model with the path to the checkpoint to export it in a format that is compatible with micro_sam functionality. Note that the exported model should only be used for automatic segmentation via AIS.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
  • loss: The loss function to train the instance segmentation model. By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
  • metric: The metric for the instance segmentation training. By default the loss function is used as the metric.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. By default nothing is frozen and the full model is updated.
  • device: The device to use for training. By default, automatically chooses the best available device to train.
  • lr: The learning rate. By default, set to '1e-5'.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • n_iterations: The number of iterations to use for training. This will override n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
  • overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
  • model_kwargs: Additional keyword arguments for the micro_sam.util.get_sam_model.
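
A minimal sketch for decoder-only training; loaders are built with `default_sam_loader` and `train_instance_segmentation_only=True`, and all names and paths are hypothetical:

    from micro_sam.training.training import default_sam_loader, train_instance_segmentation

    loader_kwargs = dict(
        raw_paths="data/images", raw_key="*.tif",
        label_paths="data/labels", label_key="*.tif",
        patch_shape=(512, 512),
        with_segmentation_decoder=True,
        train_instance_segmentation_only=True,  # do not return the instance channel in the targets
        batch_size=1, shuffle=True,
    )
    train_loader = default_sam_loader(**loader_kwargs)
    val_loader = default_sam_loader(**loader_kwargs)  # in practice, use separate validation data

    train_instance_segmentation(
        name="ais_only",                         # the checkpoint is stored under checkpoints/ais_only
        model_type="vit_b",
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=50,
    )
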
def default_sam_dataset( raw_paths: Union[List[Union[numpy.ndarray, torch.Tensor]], List[Union[str, os.PathLike]], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Union[numpy.ndarray, torch.Tensor]], List[Union[str, os.PathLike]], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: Optional[bool] = None, train_instance_segmentation_only: bool = False, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, is_train: bool = True, min_size: int = 25, max_sampling_attempts: Optional[int] = None, rois: Union[slice, Tuple[slice, ...], NoneType] = None, is_multi_tensor: bool = True, **kwargs) -> torch.utils.data.dataset.Dataset:
615def default_sam_dataset(
616    raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
617    raw_key: Optional[str],
618    label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
619    label_key: Optional[str],
620    patch_shape: Tuple[int],
621    with_segmentation_decoder: bool,
622    with_channels: Optional[bool] = None,
623    train_instance_segmentation_only: bool = False,
624    sampler: Optional[Callable] = None,
625    raw_transform: Optional[Callable] = None,
626    n_samples: Optional[int] = None,
627    is_train: bool = True,
628    min_size: int = 25,
629    max_sampling_attempts: Optional[int] = None,
630    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
631    is_multi_tensor: bool = True,
632    **kwargs,
633) -> Dataset:
634    """Create a PyTorch Dataset for training a SAM model.
635
636    Args:
637        raw_paths: The path(s) to the image data used for training.
638            Can either be multiple 2D images or volumetric data.
639            The data can also be passed as a list of numpy arrays or torch tensors.
640        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
641            or a glob pattern for selecting multiple files.
642            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
643        label_paths: The path(s) to the label data used for training.
644            Can either be multiple 2D images or volumetric data.
645            The data can also be passed as a list of numpy arrays or torch tensors.
646        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
647            or a glob pattern for selecting multiple files.
648            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
649        patch_shape: The shape for training patches.
650        with_segmentation_decoder: Whether to train with additional segmentation decoder.
651        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
652        train_instance_segmentation_only: Set this argument to True in order to
653            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
654        sampler: A sampler to reject batches according to a given criterion.
655        raw_transform: Transformation applied to the image data.
656            If not given the data will be cast to 8bit.
657        n_samples: The number of samples for this dataset.
658        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
659        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
660        max_sampling_attempts: Number of sampling attempts to make from a dataset.
661        rois: The region of interest(s) for the data.
662        is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
663        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
664
665    Returns:
666        The segmentation dataset.
667    """
668
669    # Check if this dataset should be used for instance segmentation only training.
670    # If yes, we set return_instances to False, since the instance channel must not
671    # be passed for this training mode.
672    return_instances = True
673    if train_instance_segmentation_only:
674        if not with_segmentation_decoder:
675            raise ValueError(
676                "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
677            )
678        return_instances = False
679
680    # If a sampler is not passed, then we set a default MinInstanceSampler, which rejects samples without enough distinct instances.
681    # This is necessary, because training for interactive segmentation does not work on 'empty' images.
682    # However, if we train only the automatic instance segmentation decoder, then this sampler is not required
683    # and we do not set a default sampler.
684    if sampler is None and not train_instance_segmentation_only:
685        sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)
686
687    is_seg_dataset, with_channels = _determine_dataset_and_channels(
688        raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs
689    )
690
691    # Set the data transformations.
692    if raw_transform is None:
693        raw_transform = require_8bit
694
695    # Prepare the label transform.
696    if with_segmentation_decoder:
697        default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
698            distances=True,
699            boundary_distances=True,
700            directed_distances=False,
701            foreground=True,
702            instances=return_instances,
703            min_size=min_size,
704        )
705    else:
706        default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
707
708    # Allow combining label transforms.
709    custom_label_transform = kwargs.pop("label_transform", None)
710    if custom_label_transform is None:
711        label_transform = default_label_transform
712    else:
713        label_transform = torch_em.transform.generic.Compose(
714            custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor
715        )
716
717    # Check the patch shape to add a singleton if required.
718    patch_shape = _update_patch_shape(
719        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
720    )
721
722    # Set a minimum number of samples per epoch.
723    if n_samples is None:
724        loader = torch_em.default_segmentation_loader(
725            raw_paths=raw_paths,
726            raw_key=raw_key,
727            label_paths=label_paths,
728            label_key=label_key,
729            batch_size=1,
730            patch_shape=patch_shape,
731            with_channels=with_channels,
732            ndim=2,
733            is_seg_dataset=is_seg_dataset,
734            raw_transform=raw_transform,
735            rois=rois,
736            **kwargs
737        )
738        n_samples = max(len(loader), 100 if is_train else 5)
739
740    dataset = torch_em.default_segmentation_dataset(
741        raw_paths=raw_paths,
742        raw_key=raw_key,
743        label_paths=label_paths,
744        label_key=label_key,
745        patch_shape=patch_shape,
746        raw_transform=raw_transform,
747        label_transform=label_transform,
748        with_channels=with_channels,
749        ndim=2,
750        sampler=sampler,
751        n_samples=n_samples,
752        is_seg_dataset=is_seg_dataset,
753        rois=rois,
754        **kwargs,
755    )
756
757    if max_sampling_attempts is not None:
758        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
759            for ds in dataset.datasets:
760                ds.max_sampling_attempts = max_sampling_attempts
761        else:
762            dataset.max_sampling_attempts = max_sampling_attempts
763
764    return dataset

Create a PyTorch Dataset for training a SAM model.

Arguments:
  • raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data. The data can also be passed as a list of numpy arrays or torch tensors.
  • raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files. Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
  • label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data. The data can also be passed as a list of numpy arrays or torch tensors.
  • label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files. Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
  • patch_shape: The shape for training patches.
  • with_segmentation_decoder: Whether to train with additional segmentation decoder.
  • with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
  • train_instance_segmentation_only: Set this argument to True in order to pass the dataset to train_instance_segmentation. By default, set to 'False'.
  • sampler: A sampler to reject batches according to a given criterion.
  • raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
  • n_samples: The number of samples for this dataset.
  • is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
  • min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
  • max_sampling_attempts: Number of sampling attempts to make from a dataset.
  • rois: The region of interest(s) for the data.
  • is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.
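
A short sketch for building a dataset from volumetric data; the HDF5 file and its internal keys ('raw', 'labels') are hypothetical:

    from micro_sam.training.training import default_sam_dataset

    dataset = default_sam_dataset(
        raw_paths="data/volume.h5", raw_key="raw",          # hypothetical internal dataset paths
        label_paths="data/volume.h5", label_key="labels",
        patch_shape=(1, 512, 512),                          # one z-slice per 2D training patch
        with_segmentation_decoder=True,
    )
    print(len(dataset))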

def default_sam_loader(**kwargs) -> torch.utils.data.dataloader.DataLoader:
767def default_sam_loader(**kwargs) -> DataLoader:
768    """Create a PyTorch DataLoader for training a SAM model.
769
770    Args:
771        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
772
773    Returns:
774        The DataLoader.
775    """
776    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)
777
778    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
779    # which the users can provide to get their desired segmentation dataset.
780    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
781    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}
782
783    ds = default_sam_dataset(**ds_kwargs)
784    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)

Create a PyTorch DataLoader for training a SAM model.

Arguments:
  • kwargs: Keyword arguments for micro_sam.training.default_sam_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.
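
A short sketch illustrating how dataset and DataLoader arguments can be mixed in a single call (paths are hypothetical); dataset keywords are routed to `default_sam_dataset`, the remaining ones to the PyTorch DataLoader:

    from micro_sam.training.training import default_sam_loader

    loader = default_sam_loader(
        raw_paths="data/images", raw_key="*.tif",       # dataset arguments
        label_paths="data/labels", label_key="*.tif",
        patch_shape=(512, 512),
        with_segmentation_decoder=True,
        batch_size=2, shuffle=True, num_workers=4,      # DataLoader arguments
    )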

CONFIGURATIONS = {'Minimal': {'model_type': 'vit_t', 'n_objects_per_batch': 4, 'n_sub_iteration': 4}, 'CPU': {'model_type': 'vit_b', 'n_objects_per_batch': 10}, 'gtx1080': {'model_type': 'vit_t', 'n_objects_per_batch': 5}, 'gtx3080': {'model_type': 'vit_b', 'n_objects_per_batch': 5, 'peft_kwargs': {'attention_layers_to_update': [11], 'peft_module': <class 'micro_sam.models.peft_sam.ClassicalSurgery'>}}, 'rtx5000': {'model_type': 'vit_b', 'n_objects_per_batch': 10}, 'V100': {'model_type': 'vit_b'}, 'A100': {'model_type': 'vit_h'}}

Best training configurations for given hardware resources.
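
For example, the suggested settings for a given hardware resource can be inspected directly:

    from micro_sam.training.training import CONFIGURATIONS

    # Keys are hardware resources; values are keyword arguments passed on to the training functions.
    print(list(CONFIGURATIONS.keys()))   # e.g. 'Minimal', 'CPU', 'gtx1080', ...
    print(CONFIGURATIONS["V100"])        # {'model_type': 'vit_b'}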

def train_sam_for_configuration( name: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, configuration: Optional[str] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, train_instance_segmentation_only: bool = False, model_type: Optional[str] = None, **kwargs) -> None:
826def train_sam_for_configuration(
827    name: str,
828    train_loader: DataLoader,
829    val_loader: DataLoader,
830    configuration: Optional[str] = None,
831    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
832    with_segmentation_decoder: bool = True,
833    train_instance_segmentation_only: bool = False,
834    model_type: Optional[str] = None,
835    **kwargs,
836) -> None:
837    """Run training for a SAM model with the configuration for a given hardware resource.
838
839    Selects the best training settings for the given configuration.
840    The available configurations are listed in `CONFIGURATIONS`.
841
842    Args:
843        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
844        train_loader: The dataloader for training.
845        val_loader: The dataloader for validation.
846        configuration: The configuration (= name of hardware resource).
847            By default, it is automatically selected based on the available VRAM.
848        checkpoint_path: Path to checkpoint for initializing the SAM model.
849        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
850            By default, trains with the additional instance segmentation decoder.
851        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
852            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
853        model_type: Override the default model type.
854            This can be used to start from one of the micro_sam models
855            instead of a default SAM model.
856        kwargs: Additional keyword parameters that will be passed to `train_sam`.
857    """
858    if configuration is None:  # Automatically choose based on available VRAM combination.
859        configuration = _find_best_configuration()
860
861    if configuration in CONFIGURATIONS:
862        train_kwargs = CONFIGURATIONS[configuration]
863    else:
864        raise ValueError(f"Invalid configuration {configuration}, expected one of {list(CONFIGURATIONS.keys())}.")
865
866    if model_type is None:
867        model_type = train_kwargs.pop("model_type")
868    else:
869        expected_model_type = train_kwargs.pop("model_type")
870        if model_type[:5] != expected_model_type:
871            warnings.warn("You have specified a different model type.")
872
873    train_kwargs.update(**kwargs)
874    if train_instance_segmentation_only:
875        instance_seg_kwargs, extra_kwargs = split_kwargs(train_instance_segmentation, **train_kwargs)
876        model_kwargs, extra_kwargs = split_kwargs(get_sam_model, **extra_kwargs)
877        instance_seg_kwargs.update(**model_kwargs)
878
879        train_instance_segmentation(
880            name=name,
881            train_loader=train_loader,
882            val_loader=val_loader,
883            checkpoint_path=checkpoint_path,
884            model_type=model_type,
885            **instance_seg_kwargs,
886        )
887    else:
888        train_sam(
889            name=name,
890            train_loader=train_loader,
891            val_loader=val_loader,
892            checkpoint_path=checkpoint_path,
893            with_segmentation_decoder=with_segmentation_decoder,
894            model_type=model_type,
895            **train_kwargs
896        )

Run training for a SAM model with the configuration for a given hardware resource.

Selects the best training settings for the given configuration. The available configurations are listed in CONFIGURATIONS.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs folder will have this name.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • configuration: The configuration (= name of hardware resource). By default, it is automatically selected based on the available VRAM.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
  • train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation using the training implementation train_instance_segmentation. By default, train_sam is used.
  • model_type: Override the default model type. This can be used to start from one of the micro_sam models instead of a default SAM model.
  • kwargs: Additional keyword parameters that will be passed to train_sam.
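
A minimal sketch, letting the function pick the configuration automatically; the loaders and all names below are hypothetical:

    from micro_sam.training.training import default_sam_loader, train_sam_for_configuration

    # Hypothetical loader; see default_sam_loader above.
    loader = default_sam_loader(
        raw_paths="data/images", raw_key="*.tif", label_paths="data/labels", label_key="*.tif",
        patch_shape=(512, 512), with_segmentation_decoder=True, batch_size=1, shuffle=True,
    )

    train_sam_for_configuration(
        name="sam_finetuned_auto",
        train_loader=loader,
        val_loader=loader,               # in practice, use a separate validation loader
        configuration=None,              # None: select the best configuration for the available hardware
        with_segmentation_decoder=True,
    )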