synapse_net.training.domain_adaptation
import os
from typing import Optional, Tuple

import torch
import torch_em
import torch_em.self_training as self_training

from .semisupervised_training import get_unsupervised_loader
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim


def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    sampler: Optional[callable] = None,
) -> None:
    """Run domain adaptation to transfer a network trained on a source domain for a supervised
    segmentation task to perform this task on a different target domain.

    We support different domain adaptation settings:
    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
      'supervised_val_paths' are not given.
    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
      when 'supervised_train_paths' and 'supervised_val_paths' are given.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint to the initial model trained on the source domain.
            This is used to initialize the teacher model. It can either be a directory
            (loaded with `torch_em.util.load_model`) or a file (loaded with `torch.load`).
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
            be given in order to provide training data from the source domain.
        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
            This training data is optional. If given, it is used for supervised learning and requires labels.
        supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain.
            This validation data is optional. If given, it is used for supervised learning and requires labels.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the uncertainty of network predictions, and only
            the data with higher certainty than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files
            of the unsupervised (target domain) data.
        raw_key_supervised: The key that holds the raw data inside of the hdf5 files
            of the supervised (source domain) data. Only used if `supervised_train_paths`
            and `supervised_val_paths` are given.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
        batch_size: The batch size for training.
        lr: The initial learning rate of the Adam optimizer. It is reduced by a factor of 0.5
            by a `ReduceLROnPlateau` scheduler (patience 5).
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        sampler: Optional sampler that is forwarded unchanged to `MeanTeacherTrainer`.

    Raises:
        AssertionError: If only one of `supervised_train_paths` / `supervised_val_paths` is given,
            if neither `source_checkpoint` nor supervised data is given,
            or if supervised data is given without a `label_key`.
    """
    assert (supervised_train_paths is None) == (supervised_val_paths is None), \
        "Either pass both 'supervised_train_paths' and 'supervised_val_paths' or neither of them."
    is_2d, _ = _determine_ndim(patch_shape)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data,
        # that's why we have the assertion here.
        assert supervised_train_paths is not None, \
            "Supervised training data is required when no 'source_checkpoint' is given."
        print("Mean teacher training from scratch (AdaMT)")
        model = get_2d_model(out_channels=2) if is_2d else get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            # map_location allows loading a checkpoint that was saved on GPU when only a CPU is available.
            # torch.load unpickles the full model object: only load trusted checkpoints.
            # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
            # nn.Module objects; confirm the torch version in use or pass weights_only=False.
            model = torch.load(source_checkpoint, map_location=device)
        reinit_teacher = False

    # Use the learning rate requested by the caller.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self-training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    # Unlabeled target-domain data.
    unsupervised_train_loader = get_unsupervised_loader(
        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
    )
    unsupervised_val_loader = get_unsupervised_loader(
        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
    )

    # Optional labeled source-domain data (semi-supervised setting).
    if supervised_train_paths is not None:
        assert label_key is not None, "'label_key' is required when supervised training data is given."
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=sampler,
    )
    trainer.fit(n_iterations)
def
mean_teacher_adaptation( name: str, unsupervised_train_paths: Tuple[str], unsupervised_val_paths: Tuple[str], patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, source_checkpoint: Optional[str] = None, supervised_train_paths: Optional[Tuple[str]] = None, supervised_val_paths: Optional[Tuple[str]] = None, confidence_threshold: float = 0.9, raw_key: str = 'raw', raw_key_supervised: str = 'raw', label_key: Optional[str] = None, batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 10000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, sampler: Optional[<built-in function callable>] = None) -> None:
def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    sampler: Optional[callable] = None,
) -> None:
    """Run domain adaptation to transfer a network trained on a source domain for a supervised
    segmentation task to perform this task on a different target domain.

    We support different domain adaptation settings:
    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
      'supervised_val_paths' are not given.
    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
      when 'supervised_train_paths' and 'supervised_val_paths' are given.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint to the initial model trained on the source domain.
            This is used to initialize the teacher model. It can either be a directory
            (loaded with `torch_em.util.load_model`) or a file (loaded with `torch.load`).
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
            be given in order to provide training data from the source domain.
        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
            This training data is optional. If given, it is used for supervised learning and requires labels.
        supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain.
            This validation data is optional. If given, it is used for supervised learning and requires labels.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the uncertainty of network predictions, and only
            the data with higher certainty than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files
            of the unsupervised (target domain) data.
        raw_key_supervised: The key that holds the raw data inside of the hdf5 files
            of the supervised (source domain) data. Only used if `supervised_train_paths`
            and `supervised_val_paths` are given.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
        batch_size: The batch size for training.
        lr: The initial learning rate of the Adam optimizer. It is reduced by a factor of 0.5
            by a `ReduceLROnPlateau` scheduler (patience 5).
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        sampler: Optional sampler that is forwarded unchanged to `MeanTeacherTrainer`.

    Raises:
        AssertionError: If only one of `supervised_train_paths` / `supervised_val_paths` is given,
            if neither `source_checkpoint` nor supervised data is given,
            or if supervised data is given without a `label_key`.
    """
    assert (supervised_train_paths is None) == (supervised_val_paths is None), \
        "Either pass both 'supervised_train_paths' and 'supervised_val_paths' or neither of them."
    is_2d, _ = _determine_ndim(patch_shape)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data,
        # that's why we have the assertion here.
        assert supervised_train_paths is not None, \
            "Supervised training data is required when no 'source_checkpoint' is given."
        print("Mean teacher training from scratch (AdaMT)")
        model = get_2d_model(out_channels=2) if is_2d else get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            # map_location allows loading a checkpoint that was saved on GPU when only a CPU is available.
            # torch.load unpickles the full model object: only load trusted checkpoints.
            # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
            # nn.Module objects; confirm the torch version in use or pass weights_only=False.
            model = torch.load(source_checkpoint, map_location=device)
        reinit_teacher = False

    # Use the learning rate requested by the caller.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self-training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    # Unlabeled target-domain data.
    unsupervised_train_loader = get_unsupervised_loader(
        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
    )
    unsupervised_val_loader = get_unsupervised_loader(
        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
    )

    # Optional labeled source-domain data (semi-supervised setting).
    if supervised_train_paths is not None:
        assert label_key is not None, "'label_key' is required when supervised training data is given."
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=sampler,
    )
    trainer.fit(n_iterations)
Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.
We support different domain adaptation settings:
- unsupervised domain adaptation: the default mode when 'supervised_train_paths' and 'supervised_val_paths' are not given.
- semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, when 'supervised_train_paths' and 'supervised_val_paths' are given.
Arguments:
- name: The name for the checkpoint to be trained.
- unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
- unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- source_checkpoint: Checkpoint to the initial model trained on the source domain. This is used to initialize the teacher model. If the checkpoint is not given, then both student and teacher model are initialized from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to be given in order to provide training data from the source domain.
- supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
- supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
- confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the uncertainty of network predictions, and only the data with higher certainty than this threshold is used for training.
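- raw_key: The key that holds the raw data inside of the hdf5 or similar files of the unsupervised (target domain) data.
- raw_key_supervised: The key that holds the raw data inside of the hdf5 files of the supervised (source domain) data. Only used if `supervised_train_paths` and `supervised_val_paths` are given.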
- label_key: The key that holds the labels inside of the hdf5 files for supervised learning. This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- sampler: Optional sampler that is forwarded unchanged to the mean teacher trainer.