micro_sam.training.training

import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset

import torch_em
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from ..util import get_device
from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from . import joint_sam_trainer as joint_trainers
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit


FilePath = Union[str, os.PathLike]

def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is in the range [0, 255].
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channels for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in the batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the last 3 target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}."
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the last 3 target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}."
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in the batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        counter += 1
        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
            break


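# For reference, a sketch of the batch layout that `_check_loader` enforces
# (shapes are illustrative; height and width depend on the loader):
#
#   x, y = next(iter(loader))
#   x: (batch, 1 or 3, height, width), values in [0, 255]
#   y: (batch, 4, height, width) with a segmentation decoder
#      (instance channel + foreground + normalized distance channels, the last 3 in [0, 1])
#   y: (batch, 1, height, width) without a segmentation decoder

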
# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


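# A sketch of how the trainer interacts with this wrapper; `pbar_total` and
# `pbar_update` are the qt signals emitted above:
#
#   pbar = _ProgressBarWrapper(pbar_signals)
#   pbar.total = n_iterations  # forwards the total via pbar_signals.pbar_total
#   pbar.update(1)             # forwards one step via pbar_signals.pbar_update

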
@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield


def _count_parameters(model_parameters):
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is {round(params, 2)}M")


def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs
            without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled.
        checkpoint_path: Path to a checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely
            image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will override n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for the napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of label batches to verify from the train and validation
            dataloaders. By default, 50 batches of labels are verified from the dataloaders.
        model_kwargs: Additional keyword arguments for `util.get_sam_model`.
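
    Example:
        A minimal call sketch; the loaders would typically come from `default_sam_loader`
        and the names used here are illustrative assumptions:

        >>> train_sam(
        ...     name="sam_finetuned", model_type="vit_b",
        ...     train_loader=train_loader, val_loader=val_loader,
        ...     n_epochs=50, with_segmentation_decoder=True,
        ... )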
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

        # Create the UNETR decoder (if we train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from the UNETR.
            joint_model_params = list(model.parameters())  # sam parameters
            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            optimizer = optimizer_class(joint_model_params, lr=lr)

        else:
            optimizer = optimizer_class(model.parameters(), lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)
        seconds = int(round(t_run % 60))
        print(f"Training took {t_run:.1f} seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")


def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
    assert isinstance(path, (str, os.PathLike))

    # Check the underlying data dimensionality.
    if raw_key is None:  # If no key is given then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file from the key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        return (1,) + patch_shape
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + patch_shape
    else:
        return patch_shape


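# For example (a sketch; the path and key are illustrative assumptions):
# a 2D patch shape for 3D volumetric data gets a singleton z-dimension, so
#
#   _update_patch_shape((512, 512), "volume.h5", "raw", with_channels=False)
#
# would return (1, 512, 512).

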
def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    with_segmentation_decoder: bool,
    with_channels: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with an additional segmentation decoder.
        with_channels: Whether the image data has RGB channels.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8 bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
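
    Example:
        A minimal sketch for a folder of 2D tif images with matching label images
        (the paths and glob patterns are illustrative assumptions):

        >>> ds = default_sam_dataset(
        ...     raw_paths="images/", raw_key="*.tif",
        ...     label_paths="labels/", label_key="*.tif",
        ...     patch_shape=(512, 512), with_segmentation_decoder=True,
        ... )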
    """

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True, boundary_distances=True, directed_distances=False,
            foreground=True, instances=True, min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(
            min_size=min_size
        )

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton dimension if required.
    patch_shape = _update_patch_shape(
        patch_shape, raw_paths, raw_key, with_channels
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            ndim=2,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset


def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
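
    Example:
        Dataset and DataLoader arguments can be mixed freely, since they are split
        internally via `split_kwargs`. A sketch with illustrative paths:

        >>> loader = default_sam_loader(
        ...     raw_paths="images/", raw_key="*.tif",
        ...     label_paths="labels/", label_key="*.tif",
        ...     patch_shape=(512, 512), with_segmentation_decoder=True,
        ...     batch_size=2, shuffle=True,  # forwarded to the PyTorch DataLoader
        ... )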
    """
    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
    # which the users can provide to get their desired segmentation dataset.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

    ds = default_sam_dataset(**ds_kwargs)
    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)


CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources."""


def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to a checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        model_type: Override the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default sam model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
544    """
545    if configuration in CONFIGURATIONS:
546        train_kwargs = CONFIGURATIONS[configuration]
547    else:
548        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")
549
550    if model_type is None:
551        model_type = train_kwargs.pop("model_type")
552    else:
553        expected_model_type = train_kwargs.pop("model_type")
554        if model_type[:5] != expected_model_type:
555            warnings.warn("You have specified a different model type.")
556
557    train_kwargs.update(**kwargs)
558    train_sam(
559        name=name, train_loader=train_loader, val_loader=val_loader,
560        checkpoint_path=checkpoint_path, with_segmentation_decoder=with_segmentation_decoder,
561        model_type=model_type, **train_kwargs
562    )