synapse_net.training.domain_adaptation

  1import os
  2from typing import Optional, Tuple
  3
  4import torch
  5import torch_em
  6import torch_em.self_training as self_training
  7
  8from .semisupervised_training import get_unsupervised_loader
  9from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
 10
 11
 12def mean_teacher_adaptation(
 13    name: str,
 14    unsupervised_train_paths: Tuple[str],
 15    unsupervised_val_paths: Tuple[str],
 16    patch_shape: Tuple[int, int, int],
 17    save_root: Optional[str] = None,
 18    source_checkpoint: Optional[str] = None,
 19    supervised_train_paths: Optional[Tuple[str]] = None,
 20    supervised_val_paths: Optional[Tuple[str]] = None,
 21    confidence_threshold: float = 0.9,
 22    raw_key: str = "raw",
 23    raw_key_supervised: str = "raw",
 24    label_key: Optional[str] = None,
 25    batch_size: int = 1,
 26    lr: float = 1e-4,
 27    n_iterations: int = int(1e4),
 28    n_samples_train: Optional[int] = None,
 29    n_samples_val: Optional[int] = None,
 30    sampler: Optional[callable] = None,
 31) -> None:
 32    """Run domain adapation to transfer a network trained on a source domain for a supervised
 33    segmentation task to perform this task on a different target domain.
 34
 35    We support different domain adaptation settings:
 36    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
 37     'supervised_val_paths' are not given.
 38    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
 39      when 'supervised_train_paths' and 'supervised_val_paths' are given.
 40
 41    Args:
 42        name: The name for the checkpoint to be trained.
 43        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
 44            for the training data in the target domain.
 45            This training data is used for unsupervised learning, so it does not require labels.
 46        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
 47            for the validation data in the target domain.
 48            This validation data is used for unsupervised learning, so it does not require labels.
 49        patch_shape: The patch shape used for a training example.
 50            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 51            e.g. 'patch_shape = [1, 512, 512]'.
 52        save_root: Folder where the checkpoint will be saved.
 53        source_checkpoint: Checkpoint to the initial model trained on the source domain.
 54            This is used to initialize the teacher model.
 55            If the checkpoint is not given, then both student and teacher model are initialized
 56            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
 57            be given in order to provide training data from the source domain.
 58        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
 59            This training data is optional. If given, it is used for unsupervised learnig and requires labels.
 60        supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
 61            This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
 62        confidence_threshold: The threshold for filtering data in the unsupervised loss.
 63            The label filtering is done based on the uncertainty of network predictions, and only
 64            the data with higher certainty than this threshold is used for training.
 65        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
 66        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
 67            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
 68        batch_size: The batch size for training.
 69        lr: The initial learning rate.
 70        n_iterations: The number of iterations to train for.
 71        n_samples_train: The number of train samples per epoch. By default this will be estimated
 72            based on the patch_shape and size of the volumes used for training.
 73        n_samples_val: The number of val samples per epoch. By default this will be estimated
 74            based on the patch_shape and size of the volumes used for validation.
 75    """
 76    assert (supervised_train_paths is None) == (supervised_val_paths is None)
 77    is_2d, _ = _determine_ndim(patch_shape)
 78
 79    if source_checkpoint is None:
 80        # training from scratch only makes sense if we have supervised training data
 81        # that's why we have the assertion here.
 82        assert supervised_train_paths is not None
 83        print("Mean teacher training from scratch (AdaMT)")
 84        if is_2d:
 85            model = get_2d_model(out_channels=2)
 86        else:
 87            model = get_3d_model(out_channels=2)
 88        reinit_teacher = True
 89    else:
 90        print("Mean teacehr training initialized from source model:", source_checkpoint)
 91        if os.path.isdir(source_checkpoint):
 92            model = torch_em.util.load_model(source_checkpoint)
 93        else:
 94            model = torch.load(source_checkpoint)
 95        reinit_teacher = False
 96
 97    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 98    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
 99
100    # self training functionality
101    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
102    loss = self_training.DefaultSelfTrainingLoss()
103    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
104
105    unsupervised_train_loader = get_unsupervised_loader(
106        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
107    )
108    unsupervised_val_loader = get_unsupervised_loader(
109        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
110    )
111
112    if supervised_train_paths is not None:
113        assert label_key is not None
114        supervised_train_loader = get_supervised_loader(
115            supervised_train_paths, raw_key_supervised, label_key,
116            patch_shape, batch_size, n_samples=n_samples_train,
117        )
118        supervised_val_loader = get_supervised_loader(
119            supervised_val_paths, raw_key_supervised, label_key,
120            patch_shape, batch_size, n_samples=n_samples_val,
121        )
122    else:
123        supervised_train_loader = None
124        supervised_val_loader = None
125
126    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
127    trainer = self_training.MeanTeacherTrainer(
128        name=name,
129        model=model,
130        optimizer=optimizer,
131        lr_scheduler=scheduler,
132        pseudo_labeler=pseudo_labeler,
133        unsupervised_loss=loss,
134        unsupervised_loss_and_metric=loss_and_metric,
135        supervised_train_loader=supervised_train_loader,
136        unsupervised_train_loader=unsupervised_train_loader,
137        supervised_val_loader=supervised_val_loader,
138        unsupervised_val_loader=unsupervised_val_loader,
139        supervised_loss=loss,
140        supervised_loss_and_metric=loss_and_metric,
141        logger=self_training.SelfTrainingTensorboardLogger,
142        mixed_precision=True,
143        log_image_interval=100,
144        compile_model=False,
145        device=device,
146        reinit_teacher=reinit_teacher,
147        save_root=save_root,
148        sampler=sampler,
149    )
150    trainer.fit(n_iterations)
def mean_teacher_adaptation( name: str, unsupervised_train_paths: Tuple[str], unsupervised_val_paths: Tuple[str], patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, source_checkpoint: Optional[str] = None, supervised_train_paths: Optional[Tuple[str]] = None, supervised_val_paths: Optional[Tuple[str]] = None, confidence_threshold: float = 0.9, raw_key: str = 'raw', raw_key_supervised: str = 'raw', label_key: Optional[str] = None, batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 10000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, sampler: Optional[<built-in function callable>] = None) -> None:
 13def mean_teacher_adaptation(
 14    name: str,
 15    unsupervised_train_paths: Tuple[str],
 16    unsupervised_val_paths: Tuple[str],
 17    patch_shape: Tuple[int, int, int],
 18    save_root: Optional[str] = None,
 19    source_checkpoint: Optional[str] = None,
 20    supervised_train_paths: Optional[Tuple[str]] = None,
 21    supervised_val_paths: Optional[Tuple[str]] = None,
 22    confidence_threshold: float = 0.9,
 23    raw_key: str = "raw",
 24    raw_key_supervised: str = "raw",
 25    label_key: Optional[str] = None,
 26    batch_size: int = 1,
 27    lr: float = 1e-4,
 28    n_iterations: int = int(1e4),
 29    n_samples_train: Optional[int] = None,
 30    n_samples_val: Optional[int] = None,
 31    sampler: Optional[callable] = None,
 32) -> None:
 33    """Run domain adapation to transfer a network trained on a source domain for a supervised
 34    segmentation task to perform this task on a different target domain.
 35
 36    We support different domain adaptation settings:
 37    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
 38     'supervised_val_paths' are not given.
 39    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
 40      when 'supervised_train_paths' and 'supervised_val_paths' are given.
 41
 42    Args:
 43        name: The name for the checkpoint to be trained.
 44        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
 45            for the training data in the target domain.
 46            This training data is used for unsupervised learning, so it does not require labels.
 47        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
 48            for the validation data in the target domain.
 49            This validation data is used for unsupervised learning, so it does not require labels.
 50        patch_shape: The patch shape used for a training example.
 51            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 52            e.g. 'patch_shape = [1, 512, 512]'.
 53        save_root: Folder where the checkpoint will be saved.
 54        source_checkpoint: Checkpoint to the initial model trained on the source domain.
 55            This is used to initialize the teacher model.
 56            If the checkpoint is not given, then both student and teacher model are initialized
 57            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
 58            be given in order to provide training data from the source domain.
 59        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
 60            This training data is optional. If given, it is used for unsupervised learnig and requires labels.
 61        supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
 62            This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
 63        confidence_threshold: The threshold for filtering data in the unsupervised loss.
 64            The label filtering is done based on the uncertainty of network predictions, and only
 65            the data with higher certainty than this threshold is used for training.
 66        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
 67        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
 68            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
 69        batch_size: The batch size for training.
 70        lr: The initial learning rate.
 71        n_iterations: The number of iterations to train for.
 72        n_samples_train: The number of train samples per epoch. By default this will be estimated
 73            based on the patch_shape and size of the volumes used for training.
 74        n_samples_val: The number of val samples per epoch. By default this will be estimated
 75            based on the patch_shape and size of the volumes used for validation.
 76    """
 77    assert (supervised_train_paths is None) == (supervised_val_paths is None)
 78    is_2d, _ = _determine_ndim(patch_shape)
 79
 80    if source_checkpoint is None:
 81        # training from scratch only makes sense if we have supervised training data
 82        # that's why we have the assertion here.
 83        assert supervised_train_paths is not None
 84        print("Mean teacher training from scratch (AdaMT)")
 85        if is_2d:
 86            model = get_2d_model(out_channels=2)
 87        else:
 88            model = get_3d_model(out_channels=2)
 89        reinit_teacher = True
 90    else:
 91        print("Mean teacehr training initialized from source model:", source_checkpoint)
 92        if os.path.isdir(source_checkpoint):
 93            model = torch_em.util.load_model(source_checkpoint)
 94        else:
 95            model = torch.load(source_checkpoint)
 96        reinit_teacher = False
 97
 98    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 99    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
100
101    # self training functionality
102    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
103    loss = self_training.DefaultSelfTrainingLoss()
104    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
105
106    unsupervised_train_loader = get_unsupervised_loader(
107        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
108    )
109    unsupervised_val_loader = get_unsupervised_loader(
110        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
111    )
112
113    if supervised_train_paths is not None:
114        assert label_key is not None
115        supervised_train_loader = get_supervised_loader(
116            supervised_train_paths, raw_key_supervised, label_key,
117            patch_shape, batch_size, n_samples=n_samples_train,
118        )
119        supervised_val_loader = get_supervised_loader(
120            supervised_val_paths, raw_key_supervised, label_key,
121            patch_shape, batch_size, n_samples=n_samples_val,
122        )
123    else:
124        supervised_train_loader = None
125        supervised_val_loader = None
126
127    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
128    trainer = self_training.MeanTeacherTrainer(
129        name=name,
130        model=model,
131        optimizer=optimizer,
132        lr_scheduler=scheduler,
133        pseudo_labeler=pseudo_labeler,
134        unsupervised_loss=loss,
135        unsupervised_loss_and_metric=loss_and_metric,
136        supervised_train_loader=supervised_train_loader,
137        unsupervised_train_loader=unsupervised_train_loader,
138        supervised_val_loader=supervised_val_loader,
139        unsupervised_val_loader=unsupervised_val_loader,
140        supervised_loss=loss,
141        supervised_loss_and_metric=loss_and_metric,
142        logger=self_training.SelfTrainingTensorboardLogger,
143        mixed_precision=True,
144        log_image_interval=100,
145        compile_model=False,
146        device=device,
147        reinit_teacher=reinit_teacher,
148        save_root=save_root,
149        sampler=sampler,
150    )
151    trainer.fit(n_iterations)

Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.

We support different domain adaptation settings:

  • unsupervised domain adaptation: the default mode when 'supervised_train_paths' and 'supervised_val_paths' are not given.
  • semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, when 'supervised_train_paths' and 'supervised_val_paths' are given.
Arguments:
  • name: The name for the checkpoint to be trained.
  • unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
  • unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • source_checkpoint: Checkpoint to the initial model trained on the source domain. This is used to initialize the teacher model. If the checkpoint is not given, then both student and teacher model are initialized from scratch. In this case supervised_train_paths and supervised_val_paths have to be given in order to provide training data from the source domain.
  • supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
  • supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
  • confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the uncertainty of network predictions, and only the data with higher certainty than this threshold is used for training.
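  • raw_key: The key that holds the raw data inside of the hdf5 or similar files.
  • raw_key_supervised: The key that holds the raw data inside of the hdf5 files for supervised learning. This is only used if supervised_train_paths and supervised_val_paths are given.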
  • label_key: The key that holds the labels inside of the hdf5 files for supervised learning. This is only required if supervised_train_paths and supervised_val_paths are given.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
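  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • sampler: Optional sampler that is passed on to the mean teacher trainer (self_training.MeanTeacherTrainer).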