synapse_net.training.domain_adaptation
import os
import tempfile
from glob import glob
from pathlib import Path
from typing import Optional, Tuple

import mrcfile
import torch
import torch_em
import torch_em.self_training as self_training
from elf.io import open_file
from sklearn.model_selection import train_test_split

from .semisupervised_training import get_unsupervised_loader
from .supervised_training import (
    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
)
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
from ..inference.util import _Scaler


def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    train_mask_paths: Optional[Tuple[str]] = None,
    val_mask_paths: Optional[Tuple[str]] = None,
    patch_sampler: Optional[callable] = None,
    pseudo_label_sampler: Optional[callable] = None,
    device: int = 0,
) -> None:
    """Run domain adaptation to transfer a network trained on a source domain for a supervised
    segmentation task to perform this task on a different target domain.

    We support different domain adaptation settings:
    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
      'supervised_val_paths' are not given.
    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
      when 'supervised_train_paths' and 'supervised_val_paths' are given.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            To run 2d training, pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint of the initial model trained on the source domain.
            This is used to initialize the teacher model.
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
            be given in order to provide training data from the source domain.
        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
            This training data is optional. If given, it is used for supervised learning and requires labels.
        supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain.
            This validation data is optional. If given, it is used for supervised learning and requires labels.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the confidence of the network predictions, and only
            the data with higher confidence than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
        raw_key_supervised: The key that holds the raw data inside of the hdf5 files used for supervised learning.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
        patch_sampler: Accept or reject patches based on a condition.
        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident
            before updating the gradients.
        device: GPU ID for training.
    """
    assert (supervised_train_paths is None) == (supervised_val_paths is None)
    is_2d, _ = _determine_ndim(patch_shape)

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data,
        # that's why we have the assertion here.
        assert supervised_train_paths is not None
        print("Mean teacher training from scratch (AdaMT)")
        if is_2d:
            model = get_2d_model(out_channels=2)
        else:
            model = get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    # Use the learning rate passed to this function instead of a hardcoded value.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self-training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    unsupervised_train_loader = get_unsupervised_loader(
        data_paths=unsupervised_train_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_train,
        sample_mask_paths=train_mask_paths,
        sampler=patch_sampler,
    )
    unsupervised_val_loader = get_unsupervised_loader(
        data_paths=unsupervised_val_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_val,
        sample_mask_paths=val_mask_paths,
        sampler=patch_sampler,
    )

    if supervised_train_paths is not None:
        assert label_key is not None
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=pseudo_label_sampler,
    )
    trainer.fit(n_iterations)


# TODO patch shapes for other models
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""


def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if len(files) == 0:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have fewer than 4 files, then we crop a part of the volumes for validation
    # and resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave val crops or resize the training data.
    resave_data = resave_val_crops or resize_training_data
    if not resave_data:
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    train_paths, val_paths = [], []
    for file_path in files:
        file_name = os.path.basename(file_path)
        data = open_file(file_path, mode="r")["data"][:]

        if resize_training_data:
            with mrcfile.open(file_path) as f:
                voxel_size = f.voxel_size
            # Convert the voxel size from Angstrom to nanometer.
            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
            scale = compute_scale_from_voxel_size(voxel_size, model_name)
            scaler = _Scaler(scale, verbose=False)
            data = scaler.scale_input(data)

        if resave_val_crops:
            # Crop the last 'val_fraction' of the z-slices for validation.
            n_slices = data.shape[0]
            val_slice = int((1.0 - val_fraction) * n_slices)
            train_data, val_data = data[:val_slice], data[val_slice:]

            train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
            with open_file(train_path, mode="w") as f:
                f.create_dataset("data", data=train_data, compression="lzf")
            train_paths.append(train_path)

            val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
            with open_file(val_path, mode="w") as f:
                f.create_dataset("data", data=val_data, compression="lzf")
            val_paths.append(val_path)

        else:
            output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
            with open_file(output_path, mode="w") as f:
                f.create_dataset("data", data=data, compression="lzf")
            train_paths.append(output_path)

    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths


def _parse_patch_shape(patch_shape, model_name):
    if patch_shape is None:
        patch_shape = PATCH_SHAPES[model_name]
    return patch_shape


def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
                    "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
                    "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
                    "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument). "  # noqa
                    "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
                    "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained.")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
             "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's training data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data, args.source_model, tmp_dir,
            args.val_fraction,
        )
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
        )
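A note on the resizing logic in '_get_paths' above: mrcfile reports voxel sizes in Angstrom, which is presumably why the code divides by 10 before passing nanometer values to 'compute_scale_from_voxel_size'. A minimal standalone sketch of that conversion, using a hypothetical helper name:

import mrcfile

def read_voxel_size_nm(path):
    # Hypothetical helper mirroring the logic in _get_paths:
    # mrcfile reports voxel sizes in Angstrom, so divide by 10 for nanometer.
    with mrcfile.open(path, permissive=True) as f:
        voxel_size = f.voxel_size  # record array with fields x, y, z
    return {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}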
def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    train_mask_paths: Optional[Tuple[str]] = None,
    val_mask_paths: Optional[Tuple[str]] = None,
    patch_sampler: Optional[callable] = None,
    pseudo_label_sampler: Optional[callable] = None,
    device: int = 0,
) -> None:
Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.
We support different domain adaptation settings:
- unsupervised domain adaptation: the default mode when 'supervised_train_paths' and 'supervised_val_paths' are not given.
- semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, when 'supervised_train_paths' and 'supervised_val_paths' are given.
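For orientation, here is a minimal usage sketch of both settings; the file paths, checkpoint location, and label key below are hypothetical placeholders:

from synapse_net.training.domain_adaptation import mean_teacher_adaptation

# Unsupervised domain adaptation: adapt a model trained on the source domain
# using only unlabeled data from the target domain.
mean_teacher_adaptation(
    name="adapted_model",
    unsupervised_train_paths=("target_train.h5",),
    unsupervised_val_paths=("target_val.h5",),
    patch_shape=(48, 256, 256),
    source_checkpoint="/path/to/source_checkpoint",
)

# Semi-supervised domain adaptation: additionally pass labeled data from the
# source domain; in this case a source checkpoint does not have to be given.
mean_teacher_adaptation(
    name="adapted_model_semisupervised",
    unsupervised_train_paths=("target_train.h5",),
    unsupervised_val_paths=("target_val.h5",),
    patch_shape=(48, 256, 256),
    supervised_train_paths=("source_train.h5",),
    supervised_val_paths=("source_val.h5",),
    label_key="labels",
)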
Arguments:
- name: The name for the checkpoint to be trained.
- unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
- unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
- patch_shape: The patch shape used for a training example. To run 2d training, pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- source_checkpoint: Checkpoint of the initial model trained on the source domain. This is used to initialize the teacher model. If the checkpoint is not given, then both student and teacher model are initialized from scratch. In this case 'supervised_train_paths' and 'supervised_val_paths' have to be given in order to provide training data from the source domain.
- supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
- supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
- confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the confidence of the network predictions, and only the data with higher confidence than this threshold is used for training (see the sketch after this list).
- raw_key: The key that holds the raw data inside of the hdf5 or similar files.
- raw_key_supervised: The key that holds the raw data inside of the hdf5 files used for supervised learning.
- label_key: The key that holds the labels inside of the hdf5 files for supervised learning. This is only required if 'supervised_train_paths' and 'supervised_val_paths' are given.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
- val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
- patch_sampler: Accept or reject patches based on a condition.
- pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
- device: GPU ID for training.
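To make the role of 'confidence_threshold' concrete, here is an illustrative sketch of confidence-based pseudo-label filtering for a binary segmentation output. It is not the actual implementation of 'torch_em.self_training.DefaultPseudoLabeler', just the idea behind it:

import torch

def filter_pseudo_labels(teacher_logits, confidence_threshold=0.9):
    # Derive binary pseudo-labels from the teacher predictions.
    probabilities = torch.sigmoid(teacher_logits)
    pseudo_labels = (probabilities > 0.5).float()
    # A prediction is confident if its probability is close to 0 or 1.
    mask = (probabilities > confidence_threshold) | (probabilities < 1.0 - confidence_threshold)
    # The mask can be used to exclude low-confidence regions from the unsupervised loss.
    return pseudo_labels, mask.float()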