micro_sam.training.training

   1import os
   2import time
   3import warnings
   4from glob import glob
   5from tqdm import tqdm
   6from collections import OrderedDict
   7from contextlib import contextmanager, nullcontext
   8from typing import Any, Callable, Dict, List, Optional, Tuple, Union
   9
  10import imageio.v3 as imageio
  11import numpy as np
  12import torch
  13from torch.optim import Optimizer
  14from torch.utils.data import random_split
  15from torch.utils.data import DataLoader, Dataset
  16from torch.optim.lr_scheduler import _LRScheduler
  17
  18import torch_em
  19from torch_em.util import load_data
  20from torch_em.data.datasets.util import split_kwargs
  21
  22from elf.io import open_file
  23
  24try:
  25    from qtpy.QtCore import QObject
  26except Exception:
  27    QObject = Any
  28
  29from . import sam_trainer as trainers
  30from ..instance_segmentation import get_unetr
  31from ..models.peft_sam import ClassicalSurgery
  32from . import joint_sam_trainer as joint_trainers
  33from ..util import get_device, get_model_names, export_custom_sam_model, get_sam_model
  34from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform
  35
  36
  37FilePath = Union[str, os.PathLike]
  38
  39
  40def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
  41    x, _ = next(iter(loader))
  42
  43    # Raw data: check that we have 1 or 3 channels.
  44    n_channels = x.shape[1]
  45    if n_channels not in (1, 3):
  46        raise ValueError(
  47            "Invalid number of channels for the input data from the data loader. "
  48            f"Expect 1 or 3 channels, got {n_channels}."
  49        )
  50
  51    # Raw data: check that it is between [0, 255]
  52    minval, maxval = x.min(), x.max()
  53    if minval < 0 or minval > 255:
  54        raise ValueError(
  55            "Invalid data range for the input data from the data loader. "
  56            f"The input has to be in range [0, 255], but got minimum value {minval}."
  57        )
  58    if maxval < 1 or maxval > 255:
  59        raise ValueError(
  60            "Invalid data range for the input data from the data loader. "
  61            f"The input has to be in range [0, 255], but got maximum value {maxval}."
  62        )
  63
  64    # Target data: the check depends on whether we train with or without decoder.
  65    # NOTE: Verification step to check whether all labels from dataloader are valid (i.e. have atleast one instance).
  66
  67    def _check_instance_channel(instance_channel):
  68        unique_vals = torch.unique(instance_channel)
  69        if (unique_vals < 0).any():
  70            raise ValueError(
  71                "The target channel with the instance segmentation must not have negative values."
  72            )
  73        if len(unique_vals) == 1:
  74            raise ValueError(
  75                "The target channel with the instance segmentation must have at least one instance."
  76            )
  77        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
  78            raise ValueError(
  79                "All values in the target channel with the instance segmentation must be integer."
  80            )
  81
  82    counter = 0
  83    name = "" if name is None else f"'{name}'"
  84    for x, y in tqdm(
  85        loader,
  86        desc=f"Verifying labels in {name} dataloader",
  87        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
  88    ):
  89        n_channels_y = y.shape[1]
  90        if with_segmentation_decoder:
  91            if n_channels_y != 4:
  92                raise ValueError(
  93                    "Invalid number of channels in the target data from the data loader. "
  94                    "Expect 4 channel for training with an instance segmentation decoder, "
  95                    f"but got {n_channels_y} channels."
  96                )
  97            # Check instance channel per sample in a batch
  98            for per_y_sample in y:
  99                _check_instance_channel(per_y_sample[0])
 100
 101            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
 102            if targets_min < 0 or targets_min > 1:
 103                raise ValueError(
 104                    "Invalid value range in the target data from the value loader. "
 105                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
 106                    f"to be in range [0.0, 1.0], but got min {targets_min}"
 107                )
 108            if targets_max < 0 or targets_max > 1:
 109                raise ValueError(
 110                    "Invalid value range in the target data from the value loader. "
 111                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
 112                    f"to be in range [0.0, 1.0], but got max {targets_max}"
 113                )
 114
 115        else:
 116            if n_channels_y != 1:
 117                raise ValueError(
 118                    "Invalid number of channels in the target data from the data loader. "
 119                    "Expect 1 channel for training without an instance segmentation decoder, "
 120                    f"but got {n_channels_y} channels."
 121                )
 122            # Check instance channel per sample in a batch
 123            for per_y_sample in y:
 124                _check_instance_channel(per_y_sample)
 125
 126        counter += 1
 127        if verify_n_labels_in_loader is not None and counter > verify_n_labels_in_loader:
 128            break
 129
 130
 131# Make the progress bar callbacks compatible with a tqdm progress bar interface.
 132class _ProgressBarWrapper:
 133    def __init__(self, signals):
 134        self._signals = signals
 135        self._total = None
 136
 137    @property
 138    def total(self):
 139        return self._total
 140
 141    @total.setter
 142    def total(self, value):
 143        self._signals.pbar_total.emit(value)
 144        self._total = value
 145
 146    def update(self, steps):
 147        self._signals.pbar_update.emit(steps)
 148
 149    def set_description(self, desc, **kwargs):
 150        self._signals.pbar_description.emit(desc)
 151
 152
 153@contextmanager
 154def _filter_warnings(ignore_warnings):
 155    if ignore_warnings:
 156        with warnings.catch_warnings():
 157            warnings.simplefilter("ignore")
 158            yield
 159    else:
 160        with nullcontext():
 161            yield
 162
 163
 164def _count_parameters(model_parameters):
 165    params = sum(p.numel() for p in model_parameters if p.requires_grad)
 166    params = params / 1e6
 167    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")
 168
 169
 170def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training):
 171    if n_iterations is None:
 172        trainer_fit_params = {"epochs": n_epochs}
 173    else:
 174        trainer_fit_params = {"iterations": n_iterations}
 175
 176    if save_every_kth_epoch is not None:
 177        trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch
 178
 179    if pbar_signals is not None:
 180        progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
 181        trainer_fit_params["progress"] = progress_bar_wrapper
 182
 183    # Avoid overwriting a trained model, if desired by the user.
 184    trainer_fit_params["overwrite_training"] = overwrite_training
 185    return trainer_fit_params
 186
 187
 188def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs):
 189    optimizer = optimizer_class(model_params, lr=lr)
 190    if scheduler_kwargs is None:
 191        scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3}
 192    scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)
 193    return optimizer, scheduler
 194
 195
 196def train_sam(
 197    name: str,
 198    model_type: str,
 199    train_loader: DataLoader,
 200    val_loader: DataLoader,
 201    n_epochs: int = 100,
 202    early_stopping: Optional[int] = 10,
 203    n_objects_per_batch: Optional[int] = 25,
 204    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 205    with_segmentation_decoder: bool = True,
 206    freeze: Optional[List[str]] = None,
 207    device: Optional[Union[str, torch.device]] = None,
 208    lr: float = 1e-5,
 209    n_sub_iteration: int = 8,
 210    save_root: Optional[Union[str, os.PathLike]] = None,
 211    mask_prob: float = 0.5,
 212    n_iterations: Optional[int] = None,
 213    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
 214    scheduler_kwargs: Optional[Dict[str, Any]] = None,
 215    save_every_kth_epoch: Optional[int] = None,
 216    pbar_signals: Optional[QObject] = None,
 217    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
 218    peft_kwargs: Optional[Dict] = None,
 219    ignore_warnings: bool = True,
 220    verify_n_labels_in_loader: Optional[int] = 50,
 221    box_distortion_factor: Optional[float] = 0.025,
 222    overwrite_training: bool = True,
 223    strict_decoder_loading: bool = True,
 224    **model_kwargs,
 225) -> None:
 226    """Run training for a SAM model.
 227
 228    Args:
 229        name: The name of the model to be trained. The checkpoint and logs will have this name.
 230        model_type: The type of the SAM model.
 231        train_loader: The dataloader for training.
 232        val_loader: The dataloader for validation.
 233        n_epochs: The number of epochs to train for.
 234        early_stopping: Enable early stopping after this number of epochs without improvement.
 235            By default, the value is set to '10' epochs.
 236        n_objects_per_batch: The number of objects per batch used to compute
 237            the loss for interative segmentation. If None all objects will be used,
 238            if given objects will be randomly sub-sampled. By default, the number of objects per batch are '25'.
 239        checkpoint_path: Path to checkpoint for initializing the SAM model.
 240        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
 241            By default, trains with the additional instance segmentation decoder.
 242        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
 243            By default nothing is frozen and the full model is updated.
 244        device: The device to use for training. By default, automatically chooses the best available device to train.
 245        lr: The learning rate. By default, set to '1e-5'.
 246        n_sub_iteration: The number of iterative prompts per training iteration.
 247            By default, the number of iterations is set to '8'.
 248        save_root: Optional root directory for saving the checkpoints and logs.
 249            If not given the current working directory is used.
 250        mask_prob: The probability for using a mask as input in a given training sub-iteration.
 251            By default, set to '0.5'.
 252        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
 253        scheduler_class: The learning rate scheduler to update the learning rate.
 254            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
 255        scheduler_kwargs: The learning rate scheduler parameters.
 256            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
 257        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
 258        pbar_signals: Controls for napari progress bar.
 259        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
 260        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 261        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
 262        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
 263            By default, 50 batches of labels are verified from the dataloaders.
 264        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
 265            By default, the distortion factor is set to '0.025'.
 266        overwrite_training: Whether to overwrite the trained model stored at the same location.
 267            By default, overwrites the trained model at each run.
 268            If set to 'False', it will avoid retraining the model if the previous run was completed.
 269        strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
 270            exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output
 271            channels if they were pre-trained for a different task. If set to False, decoders with a different
 272            output dimension can be loaded; the output channels will be re-initialized.
 273        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
 274    """
 275    with _filter_warnings(ignore_warnings):
 276
 277        t_start = time.time()
 278        if verify_n_labels_in_loader is not None:
 279            _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
 280            _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
 281
 282        device = get_device(device)
 283        # Get the trainable segment anything model.
 284        model, state = get_trainable_sam_model(
 285            model_type=model_type,
 286            device=device,
 287            freeze=freeze,
 288            checkpoint_path=checkpoint_path,
 289            return_state=True,
 290            peft_kwargs=peft_kwargs,
 291            **model_kwargs
 292        )
 293
 294        # This class creates all the training data for a batch (inputs, prompts and labels).
 295        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)
 296
 297        # Create the UNETR decoder (if train with it) and the optimizer.
 298        if with_segmentation_decoder:
 299
 300            # Get the UNETR.
 301            unetr = get_unetr(
 302                image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None),
 303                device=device, flexible_load_checkpoint=not strict_decoder_loading,
 304            )
 305
 306            # Get the parameters for SAM and the decoder from UNETR.
 307            joint_model_params = [params for params in model.parameters()]  # sam parameters
 308            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
 309                if not param_name.startswith("encoder"):
 310                    joint_model_params.append(params)
 311
 312            model_params = joint_model_params
 313        else:
 314            model_params = model.parameters()
 315
 316        optimizer, scheduler = _get_optimizer_and_scheduler(
 317            model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs
 318        )
 319
 320        # The trainer which performs training and validation.
 321        if with_segmentation_decoder:
 322            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
 323            trainer = joint_trainers.JointSamTrainer(
 324                name=name,
 325                save_root=save_root,
 326                train_loader=train_loader,
 327                val_loader=val_loader,
 328                model=model,
 329                optimizer=optimizer,
 330                device=device,
 331                lr_scheduler=scheduler,
 332                logger=joint_trainers.JointSamLogger,
 333                log_image_interval=100,
 334                mixed_precision=True,
 335                convert_inputs=convert_inputs,
 336                n_objects_per_batch=n_objects_per_batch,
 337                n_sub_iteration=n_sub_iteration,
 338                compile_model=False,
 339                unetr=unetr,
 340                instance_loss=instance_seg_loss,
 341                instance_metric=instance_seg_loss,
 342                early_stopping=early_stopping,
 343                mask_prob=mask_prob,
 344            )
 345        else:
 346            trainer = trainers.SamTrainer(
 347                name=name,
 348                train_loader=train_loader,
 349                val_loader=val_loader,
 350                model=model,
 351                optimizer=optimizer,
 352                device=device,
 353                lr_scheduler=scheduler,
 354                logger=trainers.SamLogger,
 355                log_image_interval=100,
 356                mixed_precision=True,
 357                convert_inputs=convert_inputs,
 358                n_objects_per_batch=n_objects_per_batch,
 359                n_sub_iteration=n_sub_iteration,
 360                compile_model=False,
 361                early_stopping=early_stopping,
 362                mask_prob=mask_prob,
 363                save_root=save_root,
 364            )
 365
 366        trainer_fit_params = _get_trainer_fit_params(
 367            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
 368        )
 369        trainer.fit(**trainer_fit_params)
 370
 371        t_run = time.time() - t_start
 372        hours = int(t_run // 3600)
 373        minutes = int(t_run // 60)
 374        seconds = int(round(t_run % 60, 0))
 375        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
 376
 377
 378def export_instance_segmentation_model(
 379    trained_model_path: Union[str, os.PathLike],
 380    output_path: Union[str, os.PathLike],
 381    model_type: str,
 382    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 383) -> None:
 384    """Export a model trained for instance segmentation with `train_instance_segmentation`.
 385
 386    The exported model will be compatible with the micro_sam functions, CLI and napari plugin.
 387    It should only be used for automatic segmentation and may not work well for interactive segmentation.
 388
 389    Args:
 390        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
 391        output_path: The path where the exported model will be saved.
 392        model_type: The model type.
 393        initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
 394    """
 395    state = torch.load(trained_model_path, weights_only=False, map_location="cpu")
 396    trained_state = state.get("model_state", state)
 397
 398    # Get the state of the encoder and instance segmentation decoder from the trained checkpoint.
 399    encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")])
 400    decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")])
 401
 402    # Load the original state of the model that was used as the basis of instance segmentation training.
 403    predictor = get_sam_model(
 404        model_type=model_type, checkpoint_path=initial_checkpoint_path, device="cpu",
 405    )
 406    model_state = OrderedDict(predictor.model.state_dict())
 407
 408    # Replace the image encoder weights with the trained ones.
 409    # UNETR stores the image encoder as "encoder.*"; SAM uses "image_encoder.*".
 410    for k in list(model_state.keys()):
 411        if k.startswith("image_encoder."):
 412            encoder_key = "encoder." + k[len("image_encoder."):]
 413            model_state[k] = encoder_state[encoder_key]
 414
 415    save_state = {"model_state": model_state, "decoder_state": decoder_state}
 416    torch.save(save_state, output_path)
 417
 418
 419def train_instance_segmentation(
 420    name: str,
 421    model_type: str,
 422    train_loader: DataLoader,
 423    val_loader: DataLoader,
 424    n_epochs: int = 100,
 425    early_stopping: Optional[int] = 10,
 426    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
 427    metric: Optional[torch.nn.Module] = None,
 428    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 429    freeze: Optional[List[str]] = None,
 430    device: Optional[Union[str, torch.device]] = None,
 431    lr: float = 1e-5,
 432    save_root: Optional[Union[str, os.PathLike]] = None,
 433    n_iterations: Optional[int] = None,
 434    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
 435    scheduler_kwargs: Optional[Dict[str, Any]] = None,
 436    save_every_kth_epoch: Optional[int] = None,
 437    pbar_signals: Optional[QObject] = None,
 438    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
 439    peft_kwargs: Optional[Dict] = None,
 440    ignore_warnings: bool = True,
 441    overwrite_training: bool = True,
 442    strict_decoder_loading: bool = True,
 443    **model_kwargs,
 444) -> None:
 445    """Train a UNETR for instance segmentation using the SAM encoder as backbone.
 446
 447    This setting corresponds to training a SAM model with an instance segmentation decoder,
 448    without training the model parts for interactive segmentation,
 449    i.e. without training the prompt encoder and mask decoder.
 450
 451    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
 452    will not be compatible with the micro_sam functionality.
 453    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
 454    in a format that is compatible with micro_sam functionality.
 455    Note that the exported model should only be used for automatic segmentation via AIS.
 456
 457    Args:
 458        name: The name of the model to be trained. The checkpoint and logs will have this name.
 459        model_type: The type of the SAM model.
 460        train_loader: The dataloader for training.
 461        val_loader: The dataloader for validation.
 462        n_epochs: The number of epochs to train for.
 463        early_stopping: Enable early stopping after this number of epochs without improvement.
 464            By default, the value is set to '10' epochs.
 465        loss: The loss function to train the instance segmentation model.
 466            By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'
 467        metric: The metric for the instance segmentation training.
 468            By default the loss function is used as the metric.
 469        checkpoint_path: Path to checkpoint for initializing the SAM model.
 470        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
 471            By default nothing is frozen and the full model is updated.
 472        device: The device to use for training. By default, automatically chooses the best available device to train.
 473        lr: The learning rate. By default, set to '1e-5'.
 474        save_root: Optional root directory for saving the checkpoints and logs.
 475            If not given the current working directory is used.
 476        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
 477        scheduler_class: The learning rate scheduler to update the learning rate.
 478            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
 479        scheduler_kwargs: The learning rate scheduler parameters.
 480            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
 481        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
 482        pbar_signals: Controls for napari progress bar.
 483        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
 484        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 485        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
 486        overwrite_training: Whether to overwrite the trained model stored at the same location.
 487            By default, overwrites the trained model at each run.
 488            If set to 'False', it will avoid retraining the model if the previous run was completed.
 489        strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
 490            exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output
 491            channels if they were pre-trained for a different task. If set to False, decoders with a different
 492            output dimension can be loaded; the output channels will be re-initialized.
 493        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
 494    """
 495
 496    with _filter_warnings(ignore_warnings):
 497        t_start = time.time()
 498
 499        sam_model, state = get_trainable_sam_model(
 500            model_type=model_type,
 501            device=device,
 502            checkpoint_path=checkpoint_path,
 503            return_state=True,
 504            peft_kwargs=peft_kwargs,
 505            freeze=freeze,
 506            **model_kwargs
 507        )
 508        device = get_device(device)
 509        model = get_unetr(
 510            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None),
 511            device=device, flexible_load_checkpoint=not strict_decoder_loading,
 512        )
 513
 514        optimizer, scheduler = _get_optimizer_and_scheduler(
 515            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
 516        )
 517        trainer = torch_em.trainer.DefaultTrainer(
 518            name=name,
 519            model=model,
 520            train_loader=train_loader,
 521            val_loader=val_loader,
 522            device=device,
 523            mixed_precision=True,
 524            log_image_interval=50,
 525            compile_model=False,
 526            save_root=save_root,
 527            loss=loss,
 528            metric=loss if metric is None else metric,
 529            optimizer=optimizer,
 530            lr_scheduler=scheduler,
 531            early_stopping=early_stopping,
 532        )
 533
 534        trainer_fit_params = _get_trainer_fit_params(
 535            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
 536        )
 537        trainer.fit(**trainer_fit_params)
 538
 539        t_run = time.time() - t_start
 540        hours = int(t_run // 3600)
 541        minutes = int(t_run // 60)
 542        seconds = int(round(t_run % 60, 0))
 543        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
 544
 545
 546def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
 547    if isinstance(raw_paths, (str, os.PathLike)):
 548        path = raw_paths
 549    else:
 550        path = raw_paths[0]
 551    assert isinstance(path, (str, os.PathLike, np.ndarray, torch.Tensor))
 552
 553    # Check the underlying data dimensionality.
 554    if raw_key is None and isinstance(path, (np.ndarray, torch.Tensor)):
 555        ndim = path.ndim
 556    elif raw_key is None:  # If no key is given and this is not a tensor/array then we assume it's an image file.
 557        ndim = imageio.imread(path).ndim
 558    else:  # Otherwise we try to open the file from key.
 559        try:  # First try to open it with elf.
 560            with open_file(path, "r") as f:
 561                ndim = f[raw_key].ndim
 562        except ValueError:  # This may fail for images in a folder with different sizes.
 563            # In that case we read one of the images.
 564            image_path = glob(os.path.join(path, raw_key))[0]
 565            ndim = imageio.imread(image_path).ndim
 566
 567    if not isinstance(patch_shape, tuple):
 568        patch_shape = tuple(patch_shape)
 569
 570    if ndim == 2:
 571        assert len(patch_shape) == 2
 572        return patch_shape
 573    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
 574        return (1,) + patch_shape
 575    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
 576        return (1,) + patch_shape
 577    else:
 578        return patch_shape
 579
 580
 581def _determine_dataset_and_channels(raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs):
 582    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
 583    is_seg_dataset = kwargs.pop("is_seg_dataset", None)
 584
 585    # Check if the input data is in numpy or torch. In this case determine we use
 586    # the image collection dataset heuristic (torch_em will figure it out),
 587    # and we determine the number of channels.
 588    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)):
 589        is_seg_dataset = False
 590        if with_channels is None:
 591            with_channels = raw_paths[0].ndim == 3 and raw_paths.shape[-1] == 3
 592        return is_seg_dataset, with_channels
 593
 594    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
 595    # Get valid raw paths to make checks possible.
 596    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
 597        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
 598    else:  # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath.
 599        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
 600
 601    # Load one of the raw inputs to validate whether it is RGB or not.
 602    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
 603    if test_raw_inputs.ndim == 3:
 604        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
 605            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
 606            # We need to provide a list of inputs to 'ImageCollectionDataset'.
 607            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
 608            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
 609
 610            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
 611            with_channels = False if with_channels is None else with_channels
 612
 613        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is a RGB image and has 3 channels first.
 614            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
 615            with_channels = True if with_channels is None else with_channels
 616
 617    # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'
 618    # Otherwise, let the user make the choice as priority, else set this to our suggested default.
 619    with_channels = False if with_channels is None else with_channels
 620
 621    return is_seg_dataset, with_channels
 622
 623
 624def default_sam_dataset(
 625    raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
 626    raw_key: Optional[str],
 627    label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
 628    label_key: Optional[str],
 629    patch_shape: Tuple[int],
 630    with_segmentation_decoder: bool,
 631    with_channels: Optional[bool] = None,
 632    train_instance_segmentation_only: bool = False,
 633    sampler: Optional[Callable] = None,
 634    raw_transform: Optional[Callable] = None,
 635    n_samples: Optional[int] = None,
 636    is_train: bool = True,
 637    min_size: int = 25,
 638    max_sampling_attempts: Optional[int] = None,
 639    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
 640    is_multi_tensor: bool = True,
 641    **kwargs,
 642) -> Dataset:
 643    """Create a PyTorch Dataset for training a SAM model.
 644
 645    Args:
 646        raw_paths: The path(s) to the image data used for training.
 647            Can either be multiple 2D images or volumetric data.
 648            The data can also be passed as a list of numpy arrays or torch tensors.
 649        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
 650            or a glob pattern for selecting multiple files.
 651            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
 652        label_paths: The path(s) to the label data used for training.
 653            Can either be multiple 2D images or volumetric data.
 654            The data can also be passed as a list of numpy arrays or torch tensors.
 655        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
 656            or a glob pattern for selecting multiple files.
 657            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
 658        patch_shape: The shape for training patches.
 659        with_segmentation_decoder: Whether to train with additional segmentation decoder.
 660        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
 661        train_instance_segmentation_only: Set this argument to True in order to
 662            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
 663        sampler: A sampler to reject batches according to a given criterion.
 664        raw_transform: Transformation applied to the image data.
 665            If not given the data will be cast to 8bit.
 666        n_samples: The number of samples for this dataset.
 667        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
 668        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
 669        max_sampling_attempts: Number of sampling attempts to make from a dataset.
 670        rois: The region of interest(s) for the data.
 671        is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
 672        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
 673
 674    Returns:
 675        The segmentation dataset.
 676    """
 677
 678    # Check if this dataset should be used for instance segmentation only training.
 679    # If yes, we set return_instances to False, since the instance channel must not
 680    # be passed for this training mode.
 681    return_instances = True
 682    if train_instance_segmentation_only:
 683        if not with_segmentation_decoder:
 684            raise ValueError(
 685                "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
 686            )
 687        return_instances = False
 688
 689    # If a sampler is not passed, then we set a MinInstanceSampler, which requires 3 distinct instances per sample.
 690    # This is necessary, because training for interactive segmentation does not work on 'empty' images.
 691    # However, if we train only the automatic instance segmentation decoder, then this sampler is not required
 692    # and we do not set a default sampler.
 693    if sampler is None and not train_instance_segmentation_only:
 694        sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)
 695
 696    is_seg_dataset, with_channels = _determine_dataset_and_channels(
 697        raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs
 698    )
 699    # Since `is_seg_dataset` is selected and sorted above, let's remove it out of running kwargs, if available.
 700    kwargs.pop("is_seg_dataset", None)
 701
 702    # Set the data transformations.
 703    if raw_transform is None:
 704        raw_transform = require_8bit
 705
 706    # Prepare the label transform.
 707    if with_segmentation_decoder:
 708        default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
 709            distances=True,
 710            boundary_distances=True,
 711            directed_distances=False,
 712            foreground=True,
 713            instances=return_instances,
 714            min_size=min_size,
 715        )
 716    else:
 717        default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
 718
 719    # Allow combining label transforms.
 720    custom_label_transform = kwargs.pop("label_transform", None)
 721    if custom_label_transform is None:
 722        label_transform = default_label_transform
 723    else:
 724        label_transform = torch_em.transform.generic.Compose(
 725            custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor
 726        )
 727
 728    # Check the patch shape to add a singleton if required.
 729    patch_shape = _update_patch_shape(
 730        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
 731    )
 732
 733    # Set a minimum number of samples per epoch.
 734    if n_samples is None:
 735        loader = torch_em.default_segmentation_loader(
 736            raw_paths=raw_paths,
 737            raw_key=raw_key,
 738            label_paths=label_paths,
 739            label_key=label_key,
 740            batch_size=1,
 741            patch_shape=patch_shape,
 742            with_channels=with_channels,
 743            ndim=2,
 744            is_seg_dataset=is_seg_dataset,
 745            raw_transform=raw_transform,
 746            rois=rois,
 747            **kwargs
 748        )
 749        n_samples = max(len(loader), 100 if is_train else 5)
 750
 751    dataset = torch_em.default_segmentation_dataset(
 752        raw_paths=raw_paths,
 753        raw_key=raw_key,
 754        label_paths=label_paths,
 755        label_key=label_key,
 756        patch_shape=patch_shape,
 757        raw_transform=raw_transform,
 758        label_transform=label_transform,
 759        with_channels=with_channels,
 760        ndim=2,
 761        sampler=sampler,
 762        n_samples=n_samples,
 763        is_seg_dataset=is_seg_dataset,
 764        rois=rois,
 765        **kwargs,
 766    )
 767
 768    if max_sampling_attempts is not None:
 769        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
 770            for ds in dataset.datasets:
 771                ds.max_sampling_attempts = max_sampling_attempts
 772        else:
 773            dataset.max_sampling_attempts = max_sampling_attempts
 774
 775    return dataset
 776
 777
 778def default_sam_loader(**kwargs) -> DataLoader:
 779    """Create a PyTorch DataLoader for training a SAM model.
 780
 781    Args:
 782        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
 783
 784    Returns:
 785        The DataLoader.
 786    """
 787    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)
 788
 789    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
 790    # which the users can provide to get their desired segmentation dataset.
 791    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
 792    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}
 793
 794    ds = default_sam_dataset(**ds_kwargs)
 795    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
 796
 797
 798CONFIGURATIONS = {
 799    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
 800    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
 801    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
 802    "gtx3080": {
 803        "model_type": "vit_b", "n_objects_per_batch": 5,
 804        "peft_kwargs": {"attention_layers_to_update": [11], "peft_module": ClassicalSurgery}
 805    },
 806    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
 807    "V100": {"model_type": "vit_b"},
 808    "A100": {"model_type": "vit_h"},
 809}
 810"""Best training configurations for given hardware resources.
 811"""
 812
 813
 814def _find_best_configuration():
 815    if torch.cuda.is_available():
 816
 817        # Check how much memory we have and select the best matching GPU
 818        # for the available VRAM size.
 819        _, vram = torch.cuda.mem_get_info()
 820        vram = vram / 1e9  # in GB
 821
 822        # Maybe we can get more configurations in the future.
 823        if vram > 80:  # More than 80 GB: use the A100 configurations.
 824            return "A100"
 825        elif vram > 30:  # More than 30 GB: use the V100 configurations.
 826            return "V100"
 827        elif vram > 14:  # More than 14 GB: use the RTX5000 configurations.
 828            return "rtx5000"
 829        elif vram > 8:  # More than 8 GB: use the GTX3080 configurations.
 830            return "gtx3080"
 831        else:  # Otherwise: not enough memory to train on the GPU, use CPU instead.
 832            return "CPU"
 833    else:
 834        return "CPU"
 835
 836
 837def train_sam_for_configuration(
 838    name: str,
 839    train_loader: DataLoader,
 840    val_loader: DataLoader,
 841    configuration: Optional[str] = None,
 842    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 843    with_segmentation_decoder: bool = True,
 844    train_instance_segmentation_only: bool = False,
 845    model_type: Optional[str] = None,
 846    **kwargs,
 847) -> None:
 848    """Run training for a SAM model with the configuration for a given hardware resource.
 849
 850    Selects the best training settings for the given configuration.
 851    The available configurations are listed in `CONFIGURATIONS`.
 852
 853    Args:
 854        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
 855        train_loader: The dataloader for training.
 856        val_loader: The dataloader for validation.
 857        configuration: The configuration (= name of hardware resource).
 858            By default, it is automatically selected for the best VRAM combination.
 859        checkpoint_path: Path to checkpoint for initializing the SAM model.
 860        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
 861            By default, trains with the additional instance segmentation decoder.
 862        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
 863            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
 864        model_type: Over-ride the default model type.
 865            This can be used to use one of the micro_sam models as starting point
 866            instead of a default sam model.
 867        kwargs: Additional keyword parameters that will be passed to `train_sam`.
 868    """
 869    if configuration is None:  # Automatically choose based on available VRAM combination.
 870        configuration = _find_best_configuration()
 871
 872    if configuration in CONFIGURATIONS:
 873        train_kwargs = CONFIGURATIONS[configuration]
 874    else:
 875        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")
 876
 877    if model_type is None:
 878        model_type = train_kwargs.pop("model_type")
 879    else:
 880        expected_model_type = train_kwargs.pop("model_type")
 881        if model_type[:5] != expected_model_type:
 882            warnings.warn("You have specified a different model type.")
 883
 884    train_kwargs.update(**kwargs)
 885    if train_instance_segmentation_only:
 886        instance_seg_kwargs, extra_kwargs = split_kwargs(train_instance_segmentation, **train_kwargs)
 887        model_kwargs, extra_kwargs = split_kwargs(get_sam_model, **extra_kwargs)
 888        instance_seg_kwargs.update(**model_kwargs)
 889
 890        train_instance_segmentation(
 891            name=name,
 892            train_loader=train_loader,
 893            val_loader=val_loader,
 894            checkpoint_path=checkpoint_path,
 895            model_type=model_type,
 896            **instance_seg_kwargs,
 897        )
 898    else:
 899        train_sam(
 900            name=name,
 901            train_loader=train_loader,
 902            val_loader=val_loader,
 903            checkpoint_path=checkpoint_path,
 904            with_segmentation_decoder=with_segmentation_decoder,
 905            model_type=model_type,
 906            **train_kwargs
 907        )
 908
 909
 910def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):
 911
 912    # Whether the model is stored in the current working directory or in another location.
 913    if save_root is None:
 914        save_root = os.getcwd()  # Map this to current working directory, if not specified by the user.
 915
 916    # Get the 'best' model checkpoint ready for export.
 917    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
 918    if not os.path.exists(best_checkpoint):
 919        raise FileNotFoundError(f"The trained model not found at the expected location: '{best_checkpoint}'.")
 920
 921    # Export the model if an output path has been given.
 922    if output_path:
 923
 924        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
 925        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
 926            export_custom_sam_model(
 927                checkpoint_path=best_checkpoint,
 928                model_type=model_type[:5],
 929                save_path=output_path,
 930                with_segmentation_decoder=with_segmentation_decoder,
 931            )
 932
 933        # Otherwise we export it as bioimage.io model.
 934        else:
 935            from micro_sam.bioimageio import export_sam_model
 936
 937            # Load image and corresponding labels from the val loader.
 938            with torch.no_grad():
 939                image_data, label_data = next(iter(val_loader))
 940                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()
 941
 942                # Select the first channel of the label image if we have a channel axis, i.e. contains the labels
 943                if label_data.ndim == 3:
 944                    label_data = label_data[0]  # Gets the channel with instances.
 945                assert image_data.shape == label_data.shape
 946                label_data = label_data.astype("uint32")
 947
 948                export_sam_model(
 949                    image=image_data,
 950                    label_image=label_data,
 951                    model_type=model_type[:5],
 952                    name=checkpoint_name,
 953                    output_path=output_path,
 954                    checkpoint_path=best_checkpoint,
 955                )
 956
 957        # The final path where the model has been stored.
 958        final_path = output_path
 959
 960    else:  # If no exports have been made, inform the user about the best checkpoint.
 961        final_path = best_checkpoint
 962
 963    return final_path
 964
 965
 966def _parse_segmentation_decoder(segmentation_decoder):
 967    if segmentation_decoder in ("None", "none"):
 968        with_segmentation_decoder, train_instance_segmentation_only = False, False
 969    elif segmentation_decoder == "instances":
 970        with_segmentation_decoder, train_instance_segmentation_only = True, False
 971    elif segmentation_decoder == "instances_only":
 972        with_segmentation_decoder, train_instance_segmentation_only = True, True
 973    else:
 974        raise ValueError(
 975            "The 'segmentation_decoder' argument currently supports the values:\n"
 976            f"'instances', 'instances_only', or 'None'. You have passed {segmentation_decoder}."
 977        )
 978    return with_segmentation_decoder, train_instance_segmentation_only
 979
 980
 981def main():
 982    """@private"""
 983    import argparse
 984
 985    available_models = list(get_model_names())
 986    available_models = ", ".join(available_models)
 987
 988    available_configurations = list(CONFIGURATIONS.keys())
 989    available_configurations = ", ".join(available_configurations)
 990
 991    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")
 992
 993    # Images and labels for training.
 994    parser.add_argument(
 995        "--images", required=True, type=str, nargs="*",
 996        help="Filepath to images or the directory where the image data is stored."
 997    )
 998    parser.add_argument(
 999        "--labels", required=True, type=str, nargs="*",
1000        help="Filepath to ground-truth labels or the directory where the label data is stored."
1001    )
1002    parser.add_argument(
1003        "--image_key", type=str, default=None,
1004        help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file. "
1005    )
1006    parser.add_argument(
1007        "--label_key", type=str, default=None,
1008        help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file. "
1009    )
1010
1011    # Images and labels for validation.
1012    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
1013    # Users can choose to have their explicit validation set via this feature as well.
1014    parser.add_argument(
1015        "--val_images", type=str, nargs="*",
1016        help="Filepath to images for validation or the directory where the image data is stored."
1017    )
1018    parser.add_argument(
1019        "--val_labels", type=str, nargs="*",
1020        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
1021    )
1022    parser.add_argument(
1023        "--val_image_key", type=str, default=None,
1024        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
1025    )
1026    parser.add_argument(
1027        "--val_label_key", type=str, default=None,
1028        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
1029    )
1030
1031    # Other necessary stuff for training.
1032    parser.add_argument(
1033        "--configuration", type=str, default=_find_best_configuration(),
1034        help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations}."
1035    )
1036
1037    def none_or_str(value):
1038        if value.lower() == 'none':
1039            return None
1040        return value
1041
1042    # This could be extended to train for semantic segmentation or other options.
1043    parser.add_argument(
1044        "--segmentation_decoder", type=none_or_str, default="instances",
1045        help="Whether to finetune Segment Anything Model with an additional segmentation decoder. "
1046        "The following options are possible:\n"
1047        "- 'instances' to train with an additional decoder for automatic instance segmentation. "
1048        "  This option enables using the automatic instance segmentation (AIS) mode.\n"
1049        "- 'instances_only' to train only the instance segmentation decoder. "
1050        "  In this case the parts of SAM that are used for interactive segmentation will not be trained.\n"
1051        "- 'None' to train without an additional segmentation decoder."
1052        "  This options trains only the parts of the original SAM.\n"
1053        "By default the option 'instances' is used."
1054    )
1055
1056    # Optional advanced settings a user can opt to change the values for.
1057    parser.add_argument(
1058        "-d", "--device", type=str, default=None,
1059        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
1060        "By default the most performant available device will be selected."
1061    )
1062    parser.add_argument(
1063        "--patch_shape", type=int, nargs="*", default=(512, 512),
1064        help="The choice of patch shape for training Segment Anything Model. "
1065        "By default, a patch size of 512x512 is used."
1066    )
1067    parser.add_argument(
1068        "-m", "--model_type", type=str, default=None,
1069        help=f"The Segment Anything Model that will be used for finetuning, one of {available_models}."
1070    )
1071    parser.add_argument(
1072        "--checkpoint_path", type=str, default=None,
1073        help="Checkpoint from which the SAM model will be loaded for finetuning."
1074    )
1075    parser.add_argument(
1076        "-s", "--save_root", type=str, default=None,
1077        help="The directory where the trained models and corresponding logs will be stored. "
1078        "By default, there are stored in your current working directory."
1079    )
1080    parser.add_argument(
1081        "--trained_model_name", type=str, default="sam_model",
1082        help="The custom name of trained model sub-folder. Allows users to have several trained models "
1083        "under the same 'save_root'."
1084    )
1085    parser.add_argument(
1086        "--output_path", type=str, default=None,
1087        help="The directory (eg. '/path/to/folder') or filepath (eg. '/path/to/model.pt') to export the trained model."
1088    )
1089    parser.add_argument(
1090        "--n_epochs", type=int, default=100,
1091        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
1092    )
1093    parser.add_argument(
1094        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
1095    )
1096    parser.add_argument(
1097        "--batch_size", type=int, default=1,
1098        help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1."
1099    )
1100    parser.add_argument(
1101        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
1102        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images "
1103        "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
1104    )
1105
1106    args = parser.parse_args()
1107
1108    # 1. Get all necessary stuff for training.
1109    checkpoint_name = args.trained_model_name
1110    config = args.configuration
1111    model_type = args.model_type
1112    checkpoint_path = args.checkpoint_path
1113    batch_size = args.batch_size
1114    patch_shape = args.patch_shape
1115    epochs = args.n_epochs
1116    num_workers = args.num_workers
1117    device = args.device
1118    save_root = args.save_root
1119    output_path = args.output_path
1120    with_segmentation_decoder, train_instance_segmentation_only = _parse_segmentation_decoder(args.segmentation_decoder)
1121
1122    # Get image paths and corresponding keys.
1123    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
1124    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa
1125
1126    # 2. Prepare the dataloaders.
1127
1128    # If the user wants to preprocess the inputs, we allow the possibility to do so.
1129    _raw_transform = get_raw_transform(args.preprocess)
1130
1131    # Get the dataset with files for training.
1132    dataset = default_sam_dataset(
1133        raw_paths=train_images,
1134        raw_key=train_image_key,
1135        label_paths=train_gt,
1136        label_key=train_gt_key,
1137        patch_shape=patch_shape,
1138        with_segmentation_decoder=with_segmentation_decoder,
1139        raw_transform=_raw_transform,
1140        train_instance_segmentation_only=train_instance_segmentation_only,
1141    )
1142
1143    # If val images are not exclusively provided, we create a val split from the training data.
1144    if val_images is None:
1145        assert val_gt is None and val_image_key is None and val_gt_key is None
1146        # Use 10% of the dataset for validation - at least one image - for validation.
1147        n_val = max(1, int(0.1 * len(dataset)))
1148        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])
1149
1150    else:  # If val images provided, we create a new dataset for it.
1151        train_dataset = dataset
1152        val_dataset = default_sam_dataset(
1153            raw_paths=val_images,
1154            raw_key=val_image_key,
1155            label_paths=val_gt,
1156            label_key=val_gt_key,
1157            patch_shape=patch_shape,
1158            with_segmentation_decoder=with_segmentation_decoder,
1159            train_instance_segmentation_only=train_instance_segmentation_only,
1160            raw_transform=_raw_transform,
1161        )
1162
1163    # Get the dataloaders from the datasets.
1164    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
1165    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)
1166
1167    # 3. Train the Segment Anything Model.
1168
1169    # Get a valid model and other necessary parameters for training.
1170    if model_type is not None and model_type not in available_models:
1171        raise ValueError(f"'{model_type}' is not a valid choice of model.")
1172    if config is not None and config not in available_configurations:
1173        raise ValueError(f"'{config}' is not a valid choice of configuration.")
1174
1175    if model_type is None:  # If user does not specify the model, we use the default model corresponding to the config.
1176        model_type = CONFIGURATIONS[config]["model_type"]
1177
1178    train_sam_for_configuration(
1179        name=checkpoint_name,
1180        configuration=config,
1181        model_type=model_type,
1182        train_loader=train_loader,
1183        val_loader=val_loader,
1184        n_epochs=epochs,
1185        checkpoint_path=checkpoint_path,
1186        with_segmentation_decoder=with_segmentation_decoder,
1187        freeze=None,  # TODO: Allow for PEFT.
1188        device=device,
1189        save_root=save_root,
1190        peft_kwargs=None,  # TODO: Allow for PEFT.
1191        train_instance_segmentation_only=train_instance_segmentation_only,
1192    )
1193
1194    # 4. Export the model, if desired by the user
1195    if train_instance_segmentation_only and output_path:
1196        trained_path = os.path.join("" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt")
1197        export_instance_segmentation_model(trained_path, output_path, model_type, checkpoint_path)
1198        final_path = output_path
1199    else:
1200        final_path = _export_helper(
1201            save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader,
1202        )
1203
1204    print(f"Training has finished. The trained model is saved at {final_path}.")
FilePath = typing.Union[str, os.PathLike]
def train_sam( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, n_objects_per_batch: Optional[int] = 25, checkpoint_path: Union[str, os.PathLike, NoneType] = None, with_segmentation_decoder: bool = True, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, n_sub_iteration: int = 8, save_root: Union[str, os.PathLike, NoneType] = None, mask_prob: float = 0.5, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt6.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, box_distortion_factor: Optional[float] = 0.025, overwrite_training: bool = True, strict_decoder_loading: bool = True, **model_kwargs) -> None:
def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    overwrite_training: bool = True,
    strict_decoder_loading: bool = True,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
            By default, the value is set to '10' epochs.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None, all objects will be used;
            if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
            By default, trains with the additional instance segmentation decoder.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder
            and mask_decoder. By default nothing is frozen and the full model is updated.
        device: The device to use for training. By default, automatically chooses the best available device to train.
        lr: The learning rate. By default, set to '1e-5'.
        n_sub_iteration: The number of iterative prompts per training iteration.
            By default, the number of iterations is set to '8'.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
            By default, set to '0.5'.
        n_iterations: The number of iterations to use for training. This will override `n_epochs` if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
            If set to 'None', the dataloader checks are skipped entirely.
        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
            By default, the distortion factor is set to '0.025'.
        overwrite_training: Whether to overwrite the trained model stored at the same location.
            By default, overwrites the trained model at each run.
            If set to 'False', it will avoid retraining the model if the previous run was completed.
        strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
            exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output
            channels if they were pre-trained for a different task. If set to False, decoders with a different
            output dimension can be loaded; the output channels will be re-initialized.
        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()
        if verify_n_labels_in_loader is not None:
            _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
            _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

        # Create the UNETR decoder (if train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None),
                device=device, flexible_load_checkpoint=not strict_decoder_loading,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            # The UNETR encoder is built from the SAM image encoder (see `get_unetr` call above),
            # so parameters named "encoder*" are skipped to avoid passing them to the optimizer twice.
            model_params = list(model.parameters())  # sam parameters
            model_params.extend(  # unetr's decoder parameters
                params for param_name, params in unetr.named_parameters()
                if not param_name.startswith("encoder")
            )
        else:
            model_params = model.parameters()

        optimizer, scheduler = _get_optimizer_and_scheduler(
            model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs
        )

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        trainer_fit_params = _get_trainer_fit_params(
            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
        )
        trainer.fit(**trainer_fit_params)

        # Report the wall-clock duration. Round to whole seconds first and then split with divmod,
        # so that minutes stay within [0, 59] and seconds can never be displayed as '60'.
        t_run = time.time() - t_start
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

Run training for a SAM model.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
  • n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
  • device: The device to use for training. By default, automatically chooses the best available device to train.
  • lr: The learning rate. By default, set to '1e-5'.
  • n_sub_iteration: The number of iterative prompts per training iteration. By default, the number of iterations is set to '8'.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • mask_prob: The probability for using a mask as input in a given training sub-iteration. By default, set to '0.5'.
  • n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
  • verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
  • box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. By default, the distortion factor is set to '0.025'.
  • overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
  • strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present, exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output channels if they were pre-trained for a different task. If set to False, decoders with a different output dimension can be loaded; the output channels will be re-initialized.
  • model_kwargs: Additional keyword arguments for the micro_sam.util.get_sam_model.
def export_instance_segmentation_model( trained_model_path: Union[str, os.PathLike], output_path: Union[str, os.PathLike], model_type: str, initial_checkpoint_path: Union[str, os.PathLike, NoneType] = None) -> None:
def export_instance_segmentation_model(
    trained_model_path: Union[str, os.PathLike],
    output_path: Union[str, os.PathLike],
    model_type: str,
    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Export a model trained for instance segmentation with `train_instance_segmentation`.

    The trained encoder weights are merged into a SAM state dict and saved together with the
    remaining (decoder) weights, so that the result is compatible with the micro_sam functions,
    CLI and napari plugin. The exported model should only be used for automatic segmentation
    and may not work well for interactive segmentation.

    Args:
        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
        output_path: The path where the exported model will be saved.
        model_type: The model type.
        initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).

    Raises:
        KeyError: If the trained checkpoint has no "encoder.*" entry matching one of the
            "image_encoder.*" entries of the SAM model.
    """
    checkpoint = torch.load(trained_model_path, weights_only=False, map_location="cpu")
    # Fall back to the loaded object itself if it has no "model_state" entry.
    trained_state = checkpoint.get("model_state", checkpoint)

    # Split the trained weights in a single pass: everything prefixed with "encoder" is the
    # image encoder, everything else is treated as the instance segmentation decoder.
    encoder_state, decoder_state = OrderedDict(), OrderedDict()
    for key, value in trained_state.items():
        bucket = encoder_state if key.startswith("encoder") else decoder_state
        bucket[key] = value

    # Load the original state of the model that was used as the basis of instance segmentation training.
    predictor = get_sam_model(model_type=model_type, checkpoint_path=initial_checkpoint_path, device="cpu")
    model_state = OrderedDict(predictor.model.state_dict())

    # Overwrite the image encoder weights with the trained ones.
    # The trained checkpoint names them "encoder.*", whereas SAM names them "image_encoder.*".
    sam_prefix = "image_encoder."
    for key in list(model_state):
        if not key.startswith(sam_prefix):
            continue
        model_state[key] = encoder_state["encoder." + key[len(sam_prefix):]]

    torch.save({"model_state": model_state, "decoder_state": decoder_state}, output_path)

Export a model trained for instance segmentation with train_instance_segmentation.

The exported model will be compatible with the micro_sam functions, CLI and napari plugin. It should only be used for automatic segmentation and may not work well for interactive segmentation.

Arguments:
  • trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
  • output_path: The path where the exported model will be saved.
  • model_type: The model type.
  • initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
def train_instance_segmentation( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, loss: torch.nn.modules.module.Module = DiceBasedDistanceLoss( (foreground_loss): DiceLoss() (distance_loss): DiceLoss() ), metric: Optional[torch.nn.modules.module.Module] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, save_root: Union[str, os.PathLike, NoneType] = None, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt6.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, overwrite_training: bool = True, strict_decoder_loading: bool = True, **model_kwargs) -> None:
def train_instance_segmentation(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
    metric: Optional[torch.nn.Module] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    save_root: Optional[Union[str, os.PathLike]] = None,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    overwrite_training: bool = True,
    strict_decoder_loading: bool = True,
    **model_kwargs,
) -> None:
    """Train a UNETR for instance segmentation using the SAM encoder as backbone.

    This setting corresponds to training a SAM model with an instance segmentation decoder,
    without training the model parts for interactive segmentation,
    i.e. without training the prompt encoder and mask decoder.

    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
    will not be compatible with the micro_sam functionality.
    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
    in a format that is compatible with micro_sam functionality.
    Note that the exported model should only be used for automatic segmentation via AIS.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
            By default, the value is set to '10' epochs.
        loss: The loss function to train the instance segmentation model.
            By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
            Note that the default is a single module instance created when this function is defined,
            so it is shared by all calls that do not pass their own loss.
        metric: The metric for the instance segmentation training.
            By default the loss function is used as the metric.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training. By default, automatically chooses the best available device to train.
        lr: The learning rate. By default, set to '1e-5'.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        n_iterations: The number of iterations to use for training. This will override `n_epochs` if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
        overwrite_training: Whether to overwrite the trained model stored at the same location.
            By default, overwrites the trained model at each run.
            If set to 'False', it will avoid retraining the model if the previous run was completed.
        strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
            exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output
            channels if they were pre-trained for a different task. If set to False, decoders with a different
            output dimension can be loaded; the output channels will be re-initialized.
        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
    """

    with _filter_warnings(ignore_warnings):
        t_start = time.time()

        sam_model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            freeze=freeze,
            **model_kwargs
        )
        device = get_device(device)
        # Build the UNETR on top of the SAM image encoder; only this model is trained here.
        model = get_unetr(
            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None),
            device=device, flexible_load_checkpoint=not strict_decoder_loading,
        )

        optimizer, scheduler = _get_optimizer_and_scheduler(
            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
        )
        trainer = torch_em.trainer.DefaultTrainer(
            name=name,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            mixed_precision=True,
            log_image_interval=50,
            compile_model=False,
            save_root=save_root,
            loss=loss,
            metric=loss if metric is None else metric,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            early_stopping=early_stopping,
        )

        trainer_fit_params = _get_trainer_fit_params(
            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
        )
        trainer.fit(**trainer_fit_params)

        # Report the wall-clock duration. Round to whole seconds first and then split with divmod,
        # so that minutes stay within [0, 59] and seconds can never be displayed as '60'.
        t_run = time.time() - t_start
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

Train a UNETR for instance segmentation using the SAM encoder as backbone.

This setting corresponds to training a SAM model with an instance segmentation decoder, without training the model parts for interactive segmentation, i.e. without training the prompt encoder and mask decoder.

The checkpoint of the trained model, which will be saved in 'checkpoints/<name>', will not be compatible with the micro_sam functionality. You can call the function export_instance_segmentation_model with the path to the checkpoint to export it in a format that is compatible with micro_sam functionality. Note that the exported model should only be used for automatic segmentation via AIS.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
  • loss: The loss function to train the instance segmentation model. By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
  • metric: The metric for the instance segmentation training. By default the loss function is used as the metric.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. By default nothing is frozen and the full model is updated.
  • device: The device to use for training. By default, automatically chooses the best available device to train.
  • lr: The learning rate. By default, set to '1e-5'.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
  • overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
  • strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present, exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output channels if they were pre-trained for a different task. If set to False, decoders with a different output dimension can be loaded; the output channels will be re-initialized.
  • model_kwargs: Additional keyword arguments for the micro_sam.util.get_sam_model.
def default_sam_dataset( raw_paths: Union[List[Union[numpy.ndarray, torch.Tensor]], List[Union[str, os.PathLike]], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Union[numpy.ndarray, torch.Tensor]], List[Union[str, os.PathLike]], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: Optional[bool] = None, train_instance_segmentation_only: bool = False, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, is_train: bool = True, min_size: int = 25, max_sampling_attempts: Optional[int] = None, rois: Union[slice, Tuple[slice, ...], NoneType] = None, is_multi_tensor: bool = True, **kwargs) -> torch.utils.data.dataset.Dataset:
def default_sam_dataset(
    raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    train_instance_segmentation_only: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    is_multi_tensor: bool = True,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
            The data can also be passed as a list of numpy arrays or torch tensors.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
            The data can also be passed as a list of numpy arrays or torch tensors.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
        train_instance_segmentation_only: Set this argument to True in order to
            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
        sampler: A sampler to reject batches according to a given criterion.
            If not given, a `MinInstanceSampler` requiring 3 distinct instances is used,
            unless `train_instance_segmentation_only` is set, in which case no default sampler is used.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
            If not given, it is set to the length of a batch-size-1 loader over the data,
            but at least 100 for training and at least 5 for validation.
        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        rois: The region of interest(s) for the data.
        is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
            A `label_transform` passed here is applied before the default label transform.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `train_instance_segmentation_only` is True but `with_segmentation_decoder` is False.
    """

    # Check if this dataset should be used for instance segmentation only training.
    # If yes, we set return_instances to False, since the instance channel must not
    # be passed for this training mode.
    if train_instance_segmentation_only and not with_segmentation_decoder:
        raise ValueError(
            "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
        )
    return_instances = not train_instance_segmentation_only

    # If a sampler is not passed, then we set a MinInstanceSampler, which requires 3 distinct instances per sample.
    # This is necessary, because training for interactive segmentation does not work on 'empty' images.
    # However, if we train only the automatic instance segmentation decoder, then this sampler is not required
    # and we do not set a default sampler.
    if sampler is None and not train_instance_segmentation_only:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    is_seg_dataset, with_channels = _determine_dataset_and_channels(
        raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs
    )
    # Since `is_seg_dataset` is selected and sorted above, let's remove it out of running kwargs, if available.
    # Otherwise it would be passed twice to the torch_em functions below.
    kwargs.pop("is_seg_dataset", None)

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    # Prepare the label transform.
    if with_segmentation_decoder:
        default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=return_instances,
            min_size=min_size,
        )
    else:
        default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # Allow combining label transforms: a user-provided transform is chained in front of the default one.
    # It is popped from kwargs so that 'label_transform' is not passed twice further below.
    custom_label_transform = kwargs.pop("label_transform", None)
    if custom_label_transform is None:
        label_transform = default_label_transform
    else:
        label_transform = torch_em.transform.generic.Compose(
            custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor
        )

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Set a minimum number of samples per epoch.
    # The temporary loader (batch size 1, no sampler, no label transform) is only used to query its length.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            is_seg_dataset=is_seg_dataset,
            raw_transform=raw_transform,
            rois=rois,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        rois=rois,
        **kwargs,
    )

    # The number of sampling attempts is set as an attribute after construction;
    # for a concatenated dataset it has to be set on each of the wrapped datasets.
    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset

Create a PyTorch Dataset for training a SAM model.

Arguments:
  • raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data. The data can also be passed as a list of numpy arrays or torch tensors.
  • raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files. Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
  • label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data. The data can also be passed as a list of numpy arrays or torch tensors.
  • label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files. Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
  • patch_shape: The shape for training patches.
  • with_segmentation_decoder: Whether to train with additional segmentation decoder.
  • with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
  • train_instance_segmentation_only: Set this argument to True in order to pass the dataset to train_instance_segmentation. By default, set to 'False'.
  • sampler: A sampler to reject batches according to a given criterion.
  • raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
  • n_samples: The number of samples for this dataset.
  • is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
  • min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
  • max_sampling_attempts: Number of sampling attempts to make from a dataset.
  • rois: The region of interest(s) for the data.
  • is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.

def default_sam_loader(**kwargs) -> torch.utils.data.dataloader.DataLoader:
def default_sam_loader(**kwargs) -> DataLoader:
    """Build a PyTorch DataLoader for training a SAM model.

    The keyword arguments are distributed in two steps: arguments matching the signature of
    `micro_sam.training.default_sam_dataset` or `torch_em.default_segmentation_dataset`
    are used to create the dataset, all remaining arguments are forwarded to the data loader.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Step 1: pick out the arguments named in the signature of `default_sam_dataset`.
    dataset_kwargs, remaining_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # Step 2: users may also pass parameters that are only known to
    # `torch_em.default_segmentation_dataset`; these belong to the dataset as well.
    # Whatever is left after this step is meant for the data loader.
    torch_em_ds_kwargs, dataloader_kwargs = split_kwargs(
        torch_em.default_segmentation_dataset, **remaining_kwargs
    )

    merged_ds_kwargs = dict(dataset_kwargs)
    merged_ds_kwargs.update(torch_em_ds_kwargs)

    return torch_em.segmentation.get_data_loader(
        default_sam_dataset(**merged_ds_kwargs), **dataloader_kwargs
    )

Create a PyTorch DataLoader for training a SAM model.

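Arguments:
  • kwargs: Keyword arguments for micro_sam.training.default_sam_dataset or for the PyTorch DataLoader.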
Returns:

The DataLoader.

CONFIGURATIONS = {'Minimal': {'model_type': 'vit_t', 'n_objects_per_batch': 4, 'n_sub_iteration': 4}, 'CPU': {'model_type': 'vit_b', 'n_objects_per_batch': 10}, 'gtx1080': {'model_type': 'vit_t', 'n_objects_per_batch': 5}, 'gtx3080': {'model_type': 'vit_b', 'n_objects_per_batch': 5, 'peft_kwargs': {'attention_layers_to_update': [11], 'peft_module': <class 'micro_sam.models.peft_sam.ClassicalSurgery'>}}, 'rtx5000': {'model_type': 'vit_b', 'n_objects_per_batch': 10}, 'V100': {'model_type': 'vit_b'}, 'A100': {'model_type': 'vit_h'}}

Best training configurations for given hardware resources.

def train_sam_for_configuration( name: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, configuration: Optional[str] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, with_segmentation_decoder: bool = True, train_instance_segmentation_only: bool = False, model_type: Optional[str] = None, **kwargs) -> None:
def train_sam_for_configuration(
    name: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    configuration: Optional[str] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    train_instance_segmentation_only: bool = False,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        configuration: The configuration (= name of hardware resource).
            By default, it is automatically selected for the best VRAM combination.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
            By default, trains with the additional instance segmentation decoder.
            This argument is only forwarded to `train_sam`.
        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
        model_type: Override the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default sam model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`, or to
            `train_instance_segmentation` if `train_instance_segmentation_only` is True.
            They take precedence over the settings of the selected configuration.

    Raises:
        ValueError: If `configuration` is not one of the keys in `CONFIGURATIONS`.
    """
    if configuration is None:  # Automatically choose based on available VRAM combination.
        configuration = _find_best_configuration()

    if configuration not in CONFIGURATIONS:
        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")

    # Work on a copy of the configuration. The settings are modified below ('model_type' is popped
    # and the user kwargs are merged in); doing this on the module-level dictionary itself would
    # make a second call with the same configuration fail with a KeyError and would leak the
    # kwargs of one training run into the next one.
    train_kwargs = dict(CONFIGURATIONS[configuration])

    expected_model_type = train_kwargs.pop("model_type")
    if model_type is None:
        model_type = expected_model_type
    elif model_type[:5] != expected_model_type:
        # Only the first five characters (e.g. 'vit_b') are compared with the configuration's model type.
        warnings.warn("You have specified a different model type.")

    # User-provided settings override the ones from the configuration.
    train_kwargs.update(**kwargs)

    if train_instance_segmentation_only:
        # Keep only the arguments accepted by 'train_instance_segmentation' plus the ones
        # accepted by 'get_sam_model'. Any other settings are dropped.
        instance_seg_kwargs, extra_kwargs = split_kwargs(train_instance_segmentation, **train_kwargs)
        model_kwargs, extra_kwargs = split_kwargs(get_sam_model, **extra_kwargs)
        instance_seg_kwargs.update(**model_kwargs)

        train_instance_segmentation(
            name=name,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_path=checkpoint_path,
            model_type=model_type,
            **instance_seg_kwargs,
        )
    else:
        train_sam(
            name=name,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_path=checkpoint_path,
            with_segmentation_decoder=with_segmentation_decoder,
            model_type=model_type,
            **train_kwargs
        )

Run training for a SAM model with the configuration for a given hardware resource.

Selects the best training settings for the given configuration. The available configurations are listed in CONFIGURATIONS.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs folder will have this name.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • configuration: The configuration (= name of hardware resource). By default, it is automatically selected for the best VRAM combination.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
  • train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation using the training implementation train_instance_segmentation. By default, train_sam is used.
  • model_type: Override the default model type. This can be used to use one of the micro_sam models as starting point instead of a default sam model.
  • kwargs: Additional keyword parameters that will be passed to train_sam, or to train_instance_segmentation if train_instance_segmentation_only is set.