micro_sam.training.training

import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset

import torch_em
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from ..util import get_device
from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from . import joint_sam_trainer as joint_trainers
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit


FilePath = Union[str, os.PathLike]

def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is in the range [0, 255].
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channels for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the last 3 target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}."
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the last 3 target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}."
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        counter += 1
        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
            break

# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


def _count_parameters(model_parameters):
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is {round(params, 2)}M")


@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield

def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs
            without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely:
            image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will override n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        model_kwargs: Additional keyword arguments for `util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )
        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

        # Create the UNETR decoder (if we train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and for the decoder from the UNETR.
            joint_model_params = list(model.parameters())  # The SAM parameters.
            for param_name, params in unetr.named_parameters():  # The UNETR's decoder parameters.
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            optimizer = optimizer_class(joint_model_params, lr=lr)

        else:
            optimizer = optimizer_class(model.parameters(), lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)
        seconds = int(round(t_run % 60, 0))
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

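# Usage sketch (illustrative, not part of the module source): a minimal
# fine-tuning run. The data paths below are hypothetical placeholders;
# `default_sam_loader` is defined further down in this module and should
# produce loaders in the format that `_check_loader` expects.

from micro_sam.training import default_sam_loader, train_sam

train_loader = default_sam_loader(
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True,
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    is_train=False, batch_size=1, shuffle=True,
)
train_sam(
    name="sam_finetuned", model_type="vit_b",
    train_loader=train_loader, val_loader=val_loader,
    n_epochs=10, with_segmentation_decoder=True,
)
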
def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
    assert isinstance(path, (str, os.PathLike))

    # Check the underlying data dimensionality.
    if raw_key is None:  # If no key is given then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file from the key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

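    # Illustrative mapping of patch_shape to the data dimensionality
    # (a 2D patch_shape=(ph, pw) is assumed throughout):
    # - 2D image                      -> (ph, pw) unchanged
    # - 3D volume without channels    -> (1, ph, pw), singleton z-dimension added
    # - 4D volume with channels       -> (1, ph, pw), singleton z-dimension added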
    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        return (1,) + patch_shape
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + patch_shape
    else:
        return patch_shape


def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with an additional segmentation decoder.
        with_channels: Whether the image data has RGB channels.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        is_seg_dataset: Whether to create the dataset as a `torch_em.data.SegmentationDataset`
            or as a `torch_em.data.ImageCollectionDataset`.

    Returns:
        The dataset.
    """

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True, boundary_distances=True, directed_distances=False,
            foreground=True, instances=True, min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(
            min_size=min_size
        )

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape, raw_paths, raw_key, with_channels
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths, raw_key, label_paths, label_key, batch_size=1,
            patch_shape=patch_shape, ndim=2, is_seg_dataset=is_seg_dataset,
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths, raw_key, label_paths, label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform, label_transform=label_transform,
        with_channels=with_channels, ndim=2,
        sampler=sampler, n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset

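# Usage sketch (illustrative, not part of the module source): building the
# dataset directly, e.g. to inspect individual samples. The hdf5 file and its
# internal keys "raw" / "labels" are hypothetical placeholders.

from micro_sam.training import default_sam_dataset

ds = default_sam_dataset(
    raw_paths="data/volume.h5", raw_key="raw",
    label_paths="data/volume.h5", label_key="labels",
    patch_shape=(512, 512), with_segmentation_decoder=True,
)
image, target = ds[0]  # The target has 4 channels when training with the decoder.
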
def default_sam_loader(**kwargs) -> DataLoader:
    ds_kwargs, loader_kwargs = split_kwargs(default_sam_dataset, **kwargs)
    ds = default_sam_dataset(**ds_kwargs)
    loader = torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
    return loader

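# Note (illustrative sketch): `default_sam_loader` splits its keyword arguments
# with `split_kwargs`. Arguments of `default_sam_dataset` (raw_paths, patch_shape, ...)
# go to the dataset; the rest (batch_size, shuffle, num_workers, ...) go to the
# PyTorch DataLoader. The paths below are hypothetical placeholders.

loader = default_sam_loader(
    raw_paths="data/images", raw_key="*.png",
    label_paths="data/labels", label_key="*.png",
    patch_shape=(512, 512), with_segmentation_decoder=False,
    batch_size=2, shuffle=True, num_workers=4,
)
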
CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources."""

def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        model_type: Override the default model type.
            This can be used to use one of the micro_sam models as the starting point
            instead of a default SAM model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
    """
    if configuration in CONFIGURATIONS:
        # Copy so that popping "model_type" below does not mutate CONFIGURATIONS.
        train_kwargs = dict(CONFIGURATIONS[configuration])
    else:
        raise ValueError(f"Invalid configuration {configuration}, expect one of {list(CONFIGURATIONS.keys())}.")

    if model_type is None:
        model_type = train_kwargs.pop("model_type")
    else:
        expected_model_type = train_kwargs.pop("model_type")
        if model_type[:5] != expected_model_type:
            warnings.warn(
                f"You have specified the model type '{model_type}', which differs from "
                f"the default '{expected_model_type}' for the '{configuration}' configuration."
            )

    train_kwargs.update(**kwargs)
    train_sam(
        name=name, train_loader=train_loader, val_loader=val_loader,
        checkpoint_path=checkpoint_path, with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type, **train_kwargs
    )
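
# Usage sketch (illustrative, not part of the module source): training with a
# predefined hardware configuration. "rtx5000" selects "vit_b" with 10 objects
# per batch (see CONFIGURATIONS above). train_loader and val_loader are assumed
# to be built with `default_sam_loader`, as in the sketch after `train_sam`.

from micro_sam.training import train_sam_for_configuration

train_sam_for_configuration(
    name="sam_finetuned_rtx5000", configuration="rtx5000",
    train_loader=train_loader, val_loader=val_loader,
    with_segmentation_decoder=True,
)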