synapse_net.training.domain_adaptation
import os
import tempfile
from glob import glob
from pathlib import Path
from typing import Optional, Tuple

import mrcfile
import torch
import torch_em
import torch_em.self_training as self_training
from elf.io import open_file
from sklearn.model_selection import train_test_split

from .semisupervised_training import get_unsupervised_loader
from .supervised_training import (
    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
)
from ..inference.inference import get_model_path, compute_scale_from_voxel_size, get_available_models
from ..inference.util import _Scaler


def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    train_mask_paths: Optional[Tuple[str]] = None,
    val_mask_paths: Optional[Tuple[str]] = None,
    patch_sampler: Optional[callable] = None,
    pseudo_label_sampler: Optional[callable] = None,
    device: int = 0,
    check: bool = False,
) -> None:
    """Run domain adaptation to transfer a network trained on a source domain for a supervised
    segmentation task to perform this task on a different target domain.

    We support different domain adaptation settings:
    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
      'supervised_val_paths' are not given.
    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
      when 'supervised_train_paths' and 'supervised_val_paths' are given.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint to the initial model trained on the source domain.
            This is used to initialize the teacher model.
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
            be given in order to provide training data from the source domain.
        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
            This training data is optional. If given, it is used for supervised learning and requires labels.
        supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain.
            This validation data is optional. If given, it is used for supervised learning and requires labels.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the uncertainty of network predictions, and only
            the data with higher certainty than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
        raw_key_supervised: The key that holds the raw data inside of the hdf5 files for the supervised training data.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
        patch_sampler: Accept or reject patches based on a condition.
        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident
            before updating the gradients.
        device: GPU ID for training.
        check: Whether to check the training and validation loaders instead of running training.
    """  # noqa
    assert (supervised_train_paths is None) == (supervised_val_paths is None)
    is_2d, _ = _determine_ndim(patch_shape)

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data,
        # that's why we have the assertion here.
        assert supervised_train_paths is not None
        print("Mean teacher training from scratch (AdaMT)")
        if is_2d:
            model = get_2d_model(out_channels=2)
        else:
            model = get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self-training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    unsupervised_train_loader = get_unsupervised_loader(
        data_paths=unsupervised_train_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_train,
        sample_mask_paths=train_mask_paths,
        sampler=patch_sampler,
    )
    unsupervised_val_loader = get_unsupervised_loader(
        data_paths=unsupervised_val_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_val,
        sample_mask_paths=val_mask_paths,
        sampler=patch_sampler,
    )

    if supervised_train_paths is not None:
        assert label_key is not None
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    if check:
        from torch_em.util.debug import check_loader
        check_loader(unsupervised_train_loader, n_samples=4)
        check_loader(unsupervised_val_loader, n_samples=4)
        if supervised_train_loader is not None:
            check_loader(supervised_train_loader, n_samples=4)
            check_loader(supervised_val_loader, n_samples=4)
        return

    device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=pseudo_label_sampler,
    )
    trainer.fit(n_iterations)


# TODO patch shapes for other models
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""


def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if len(files) == 0:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have fewer than 4 files then we crop a part of the volumes for validation
    # and resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave val crops or resize the training data.
    resave_data = resave_val_crops or resize_training_data
    if not resave_data:
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    train_paths, val_paths = [], []
    for file_path in files:
        file_name = os.path.basename(file_path)
        data = open_file(file_path, mode="r")["data"][:]

        if resize_training_data:
            with mrcfile.open(file_path) as f:
                voxel_size = f.voxel_size
            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
            scale = compute_scale_from_voxel_size(voxel_size, model_name)
            scaler = _Scaler(scale, verbose=False)
            data = scaler.scale_input(data)

        if resave_val_crops:
            n_slices = data.shape[0]
            val_slice = int((1.0 - val_fraction) * n_slices)
            train_data, val_data = data[:val_slice], data[val_slice:]

            train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
            with open_file(train_path, mode="w") as f:
                f.create_dataset("data", data=train_data, compression="lzf")
            train_paths.append(train_path)

            val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
            with open_file(val_path, mode="w") as f:
                f.create_dataset("data", data=val_data, compression="lzf")
            val_paths.append(val_path)

        else:
            output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
            with open_file(output_path, mode="w") as f:
                f.create_dataset("data", data=data, compression="lzf")
            train_paths.append(output_path)

    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths


def _parse_patch_shape(patch_shape, model_name):
    if patch_shape is None:
        patch_shape = PATCH_SHAPES[model_name]
    return patch_shape


def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
                    "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
                    "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
                    "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument). "  # noqa
                    "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
                    "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained.")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    available_models = get_available_models()
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
             "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used.\n"
             f"The following source models are available: {available_models}"
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's training data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data,
            args.source_model, tmp_dir, args.val_fraction,
        )
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
            check=args.check,
            save_root=args.save_root,
        )
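For programmatic use, the unsupervised adaptation that main() performs from the command line can also be run directly via mean_teacher_adaptation. The following is a minimal sketch; the file paths and the checkpoint name are placeholders, the target-domain data is assumed to already be stored as hdf5 files with a "raw" dataset (i.e. the cropping and resizing done by _get_paths is skipped), and the patch shape is the default for the 'vesicles_3d' model.

from synapse_net.inference.inference import get_model_path
from synapse_net.training.domain_adaptation import mean_teacher_adaptation

# Unlabeled tomograms from the target domain (placeholder paths).
unsupervised_train_paths = ["/path/to/target/tomo_a.h5", "/path/to/target/tomo_b.h5"]
unsupervised_val_paths = ["/path/to/target/tomo_c.h5"]

# Initialize both teacher and student from the pretrained 3d vesicle model.
source_checkpoint = get_model_path("vesicles_3d")

mean_teacher_adaptation(
    name="adapted_model",  # the checkpoint is saved under checkpoints/adapted_model
    unsupervised_train_paths=unsupervised_train_paths,
    unsupervised_val_paths=unsupervised_val_paths,
    patch_shape=[48, 256, 256],  # default patch shape of the vesicles_3d model
    source_checkpoint=source_checkpoint,
    raw_key="raw",
    n_iterations=int(1e4),
    batch_size=1,
)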
def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = 'raw',
    raw_key_supervised: str = 'raw',
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 0.0001,
    n_iterations: int = 10000,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    train_mask_paths: Optional[Tuple[str]] = None,
    val_mask_paths: Optional[Tuple[str]] = None,
    patch_sampler: Optional[callable] = None,
    pseudo_label_sampler: Optional[callable] = None,
    device: int = 0,
    check: bool = False,
) -> None:
Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.
We support different domain adaptation settings:
- unsupervised domain adaptation: the default mode when 'supervised_train_paths' and 'supervised_val_paths' are not given.
- semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, when 'supervised_train_paths' and 'supervised_val_paths' are given.
Arguments:
- name: The name for the checkpoint to be trained.
- unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
- unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- source_checkpoint: Checkpoint to the initial model trained on the source domain. This is used to initialize the teacher model. If the checkpoint is not given, then both student and teacher model are initialized from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to be given in order to provide training data from the source domain.
- supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
- supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
- confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the uncertainty of network predictions, and only the data with higher certainty than this threshold is used for training.
- raw_key: The key that holds the raw data inside of the hdf5 or similar files.
- raw_key_supervised: The key that holds the raw data inside of the hdf5 files for the supervised training data in the source domain.
- label_key: The key that holds the labels inside of the hdf5 files for supervised learning. This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
- val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
- patch_sampler: Accept or reject patches based on a condition.
- pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
- device: GPU ID for training.
- check: Whether to check the training and validation loaders instead of running training.
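For the semi-supervised setting, labeled data from the source domain is passed in addition to the unlabeled target data; without source_checkpoint, teacher and student are then trained from scratch (the AdaMT setting above). A minimal sketch, assuming hypothetical hdf5 files that contain a "raw" and a "labels" dataset:

from synapse_net.training.domain_adaptation import mean_teacher_adaptation

mean_teacher_adaptation(
    name="adapted_model_semisupervised",
    unsupervised_train_paths=["/path/to/target/tomo_a.h5"],   # unlabeled target data
    unsupervised_val_paths=["/path/to/target/tomo_b.h5"],
    patch_shape=[48, 256, 256],
    supervised_train_paths=["/path/to/source/labeled_a.h5"],  # labeled source data
    supervised_val_paths=["/path/to/source/labeled_b.h5"],
    raw_key="raw",
    raw_key_supervised="raw",
    label_key="labels",  # placeholder key for the label dataset in the supervised files
)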