synapse_net.training.domain_adaptation
import os
import tempfile
from glob import glob
from pathlib import Path
from typing import Optional, Tuple

import mrcfile
import torch
import torch_em
import torch_em.self_training as self_training
from elf.io import open_file
from sklearn.model_selection import train_test_split

from .semisupervised_training import get_unsupervised_loader
from .supervised_training import (
    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
)
from ..inference.inference import get_model_path, compute_scale_from_voxel_size, get_available_models
from ..inference.util import _Scaler

# configure weak augmentations
from torch_em.transform.invertible_augmentations import DEFAULT_WEAK_AUGMENTATIONS

# NOTE: this overrides the "intensity" entry of the shared default dictionary at import time,
# so every user of DEFAULT_WEAK_AUGMENTATIONS in this process sees the settings below.
# TODO - test settings - fixed kernel size for `RandomGaussianBlur`, fixed std for `RandomGaussianBlur`
DEFAULT_WEAK_AUGMENTATIONS["intensity"] = {
    "RandomGaussianBlur": {"kernel_size": (19, 19), "sigma": (0.1, 3.0)},
    "RandomGaussianNoise": {"mean": (0.0), "std": (0.1)},
}


def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    train_mask_paths: Optional[Tuple[str]] = None,
    val_mask_paths: Optional[Tuple[str]] = None,
    sample_mask_key: Optional[str] = None,
    unsupervised_sampler: Optional[callable] = None,
    supervised_sampler: Optional[callable] = None,
    check: bool = False,
) -> None:
    """Run domain adaptation to transfer a network trained on a source domain for a supervised
    segmentation task to perform this task on a different target domain.

    We support different domain adaptation settings:
    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
     'supervised_val_paths' are not given.
    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
      when 'supervised_train_paths' and 'supervised_val_paths' are given.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint to the initial model trained on the source domain.
            This is used to initialize the teacher model.
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
            be given in order to provide training data from the source domain.
        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
            This training data is optional. If given, it is used for supervised learning and requires labels.
        supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain.
            This validation data is optional. If given, it is used for supervised learning and requires labels.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the uncertainty of network predictions, and only
            the data with higher certainty than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files
            of the unsupervised (target domain) data.
        raw_key_supervised: The key that holds the raw data inside of the files given by
            `supervised_train_paths` and `supervised_val_paths`.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        train_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for training.
        val_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for validation.
        sample_mask_key: The key to the sample mask dataset inside each file.
        unsupervised_sampler: Sampler to accept or reject patches for the unsupervised data stream.
        supervised_sampler: Sampler to accept or reject patches for the supervised data stream.
            Pass `False` to disable.
        check: Whether to check the training and validation loaders instead of running training.

    Raises:
        AssertionError: If only one of `supervised_train_paths` / `supervised_val_paths` is given,
            if no source checkpoint and no supervised data is given, or if supervised data
            is given without `label_key`.
    """  # noqa
    assert (supervised_train_paths is None) == (supervised_val_paths is None), \
        "'supervised_train_paths' and 'supervised_val_paths' must either both be given or both be None."
    is_2d, _ = _determine_ndim(patch_shape)

    if source_checkpoint is None:
        # training from scratch only makes sense if we have supervised training data
        # that's why we have the assertion here.
        assert supervised_train_paths is not None, \
            "Training without 'source_checkpoint' requires 'supervised_train_paths' and 'supervised_val_paths'."
        print("Mean teacher training from scratch (AdaMT)")
        if is_2d:
            model = get_2d_model(out_channels=2)
        else:
            model = get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            # A directory is treated as a torch_em checkpoint folder.
            model = torch_em.util.load_model(source_checkpoint)
        else:
            # NOTE(review): weights_only=False unpickles arbitrary Python objects.
            # Only load checkpoint files from trusted sources.
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # self training functionality
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.SelfTrainingLossWithInvertibleAugmentations()
    loss_and_metric = self_training.SelfTrainingLossAndMetricWithInvertibleAugmentations()

    ndim = 2 if is_2d else 3
    augmenters = torch_em.transform.invertible_augmentations.MeanTeacherAugmenters(ndim=ndim)

    # The same unsupervised sampler and mask key are used for the train and val streams;
    # only the paths, mask paths and sample counts differ.
    unsupervised_train_loader = get_unsupervised_loader(
        data_paths=unsupervised_train_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_train,
        sample_mask_paths=train_mask_paths,
        sample_mask_key=sample_mask_key,
        sampler=unsupervised_sampler,
    )
    unsupervised_val_loader = get_unsupervised_loader(
        data_paths=unsupervised_val_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_val,
        sample_mask_paths=val_mask_paths,
        sample_mask_key=sample_mask_key,
        sampler=unsupervised_sampler,
    )

    if supervised_train_paths is not None:
        assert label_key is not None, \
            "'label_key' is required when 'supervised_train_paths' and 'supervised_val_paths' are given."
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
            sampler=supervised_sampler,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
            sampler=supervised_sampler,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    if check:
        # Only inspect the loaders; no trainer is created and no training is run.
        from torch_em.util.debug import check_loader
        check_loader(unsupervised_train_loader, n_samples=4)
        check_loader(unsupervised_val_loader, n_samples=4)
        if supervised_train_loader is not None:
            check_loader(supervised_train_loader, n_samples=4)
            check_loader(supervised_val_loader, n_samples=4)
        return

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # The same loss objects are passed for the supervised and the unsupervised stream.
    trainer = self_training.MeanTeacherTrainerWithInvertibleAugmentations(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        augmenter=augmenters,
    )
    trainer.fit(n_iterations)


# TODO patch shapes for other models
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""


def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
    """Collect the training files and split them into train and validation paths.

    Files are searched recursively below `input_folder`. If there are at least four files
    and no resizing is requested, the original files are split with `train_test_split`.
    Otherwise the data is loaded from the internal key "data", optionally rescaled,
    and re-saved as hdf5 files (dataset name "data") in `tmp_dir`.

    Args:
        input_folder: The root folder that is searched recursively.
        pattern: The glob pattern for selecting files, e.g. '*.mrc'.
        resize_training_data: Whether to rescale the data based on the voxel size read
            from the file with `mrcfile` and the scale derived for `model_name`.
        model_name: The name of the source model, used to derive the scale factor.
        tmp_dir: The folder where re-saved files are written.
        val_fraction: The fraction of the data used for validation.

    Returns:
        The list of training paths and the list of validation paths.

    Raises:
        ValueError: If no file matches `pattern` in `input_folder`.
    """
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if len(files) == 0:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation.
    # And resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave val crops or resize the training data
    resave_data = resave_val_crops or resize_training_data
    if not resave_data:
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    train_paths, val_paths = [], []
    for file_path in files:
        file_name = os.path.basename(file_path)
        data = open_file(file_path, mode="r")["data"][:]

        if resize_training_data:
            with mrcfile.open(file_path) as f:
                voxel_size = f.voxel_size
            # presumably converts the MRC voxel size from Angstrom to nanometer - TODO confirm
            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
            scale = compute_scale_from_voxel_size(voxel_size, model_name)
            scaler = _Scaler(scale, verbose=False)
            # Fixed: the method name was misspelled as 'sale_input'.
            data = scaler.scale_input(data)

        if resave_val_crops:
            # Split each volume along its first axis: the leading (1 - val_fraction)
            # of the slices go to training, the remainder to validation.
            n_slices = data.shape[0]
            val_slice = int((1.0 - val_fraction) * n_slices)
            train_data, val_data = data[:val_slice], data[val_slice:]

            train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
            with open_file(train_path, mode="w") as f:
                f.create_dataset("data", data=train_data, compression="lzf")
            train_paths.append(train_path)

            val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
            with open_file(val_path, mode="w") as f:
                f.create_dataset("data", data=val_data, compression="lzf")
            val_paths.append(val_path)

        else:
            output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
            with open_file(output_path, mode="w") as f:
                f.create_dataset("data", data=data, compression="lzf")
            train_paths.append(output_path)

    # With enough files (>= 4) the re-saved volumes are split on the file level instead.
    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths


def _parse_patch_shape(patch_shape, model_name):
    """Return `patch_shape`, or the entry of `PATCH_SHAPES` for `model_name` if it is None.

    Raises:
        KeyError: If `patch_shape` is None and `model_name` has no entry in `PATCH_SHAPES`.
    """
    if patch_shape is None:
        patch_shape = PATCH_SHAPES[model_name]
    return patch_shape


def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
        "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
        "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument)."  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained. ")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    available_models = get_available_models()
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
        "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used.\n"
        f"The following source models are available: {available_models}"
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's trainign data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    # Training runs inside the context manager because `_get_paths` may write
    # the training data into `tmp_dir`, which is deleted when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data,
            args.source_model, tmp_dir, args.val_fraction,
        )
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
            check=args.check,
            save_root=args.save_root,
        )
def
mean_teacher_adaptation( name: str, unsupervised_train_paths: Tuple[str], unsupervised_val_paths: Tuple[str], patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, source_checkpoint: Optional[str] = None, supervised_train_paths: Optional[Tuple[str]] = None, supervised_val_paths: Optional[Tuple[str]] = None, confidence_threshold: float = 0.9, raw_key: str = 'raw', raw_key_supervised: str = 'raw', label_key: Optional[str] = None, batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 10000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, train_mask_paths: Optional[Tuple[str]] = None, val_mask_paths: Optional[Tuple[str]] = None, sample_mask_key: Optional[str] = None, unsupervised_sampler: Optional[<built-in function callable>] = None, supervised_sampler: Optional[<built-in function callable>] = None, check: bool = False) -> None:
def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    train_mask_paths: Optional[Tuple[str]] = None,
    val_mask_paths: Optional[Tuple[str]] = None,
    sample_mask_key: Optional[str] = None,
    unsupervised_sampler: Optional[callable] = None,
    supervised_sampler: Optional[callable] = None,
    check: bool = False,
) -> None:
    """Run domain adaptation to transfer a network trained on a source domain for a supervised
    segmentation task to perform this task on a different target domain.

    We support different domain adaptation settings:
    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
     'supervised_val_paths' are not given.
    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
      when 'supervised_train_paths' and 'supervised_val_paths' are given.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint to the initial model trained on the source domain.
            This is used to initialize the teacher model.
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
            be given in order to provide training data from the source domain.
        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
            This training data is optional. If given, it is used for supervised learning and requires labels.
        supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain.
            This validation data is optional. If given, it is used for supervised learning and requires labels.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the uncertainty of network predictions, and only
            the data with higher certainty than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files
            of the unsupervised (target domain) data.
        raw_key_supervised: The key that holds the raw data inside of the files given by
            `supervised_train_paths` and `supervised_val_paths`.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        train_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for training.
        val_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for validation.
        sample_mask_key: The key to the sample mask dataset inside each file.
        unsupervised_sampler: Sampler to accept or reject patches for the unsupervised data stream.
        supervised_sampler: Sampler to accept or reject patches for the supervised data stream.
            Pass `False` to disable.
        check: Whether to check the training and validation loaders instead of running training.

    Raises:
        AssertionError: If only one of `supervised_train_paths` / `supervised_val_paths` is given,
            if no source checkpoint and no supervised data is given, or if supervised data
            is given without `label_key`.
    """  # noqa
    assert (supervised_train_paths is None) == (supervised_val_paths is None), \
        "'supervised_train_paths' and 'supervised_val_paths' must either both be given or both be None."
    is_2d, _ = _determine_ndim(patch_shape)

    if source_checkpoint is None:
        # training from scratch only makes sense if we have supervised training data
        # that's why we have the assertion here.
        assert supervised_train_paths is not None, \
            "Training without 'source_checkpoint' requires 'supervised_train_paths' and 'supervised_val_paths'."
        print("Mean teacher training from scratch (AdaMT)")
        if is_2d:
            model = get_2d_model(out_channels=2)
        else:
            model = get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            # A directory is treated as a torch_em checkpoint folder.
            model = torch_em.util.load_model(source_checkpoint)
        else:
            # NOTE(review): weights_only=False unpickles arbitrary Python objects.
            # Only load checkpoint files from trusted sources.
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # self training functionality
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.SelfTrainingLossWithInvertibleAugmentations()
    loss_and_metric = self_training.SelfTrainingLossAndMetricWithInvertibleAugmentations()

    ndim = 2 if is_2d else 3
    augmenters = torch_em.transform.invertible_augmentations.MeanTeacherAugmenters(ndim=ndim)

    # The same unsupervised sampler and mask key are used for the train and val streams;
    # only the paths, mask paths and sample counts differ.
    unsupervised_train_loader = get_unsupervised_loader(
        data_paths=unsupervised_train_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_train,
        sample_mask_paths=train_mask_paths,
        sample_mask_key=sample_mask_key,
        sampler=unsupervised_sampler,
    )
    unsupervised_val_loader = get_unsupervised_loader(
        data_paths=unsupervised_val_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_val,
        sample_mask_paths=val_mask_paths,
        sample_mask_key=sample_mask_key,
        sampler=unsupervised_sampler,
    )

    if supervised_train_paths is not None:
        assert label_key is not None, \
            "'label_key' is required when 'supervised_train_paths' and 'supervised_val_paths' are given."
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
            sampler=supervised_sampler,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
            sampler=supervised_sampler,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    if check:
        # Only inspect the loaders; no trainer is created and no training is run.
        from torch_em.util.debug import check_loader
        check_loader(unsupervised_train_loader, n_samples=4)
        check_loader(unsupervised_val_loader, n_samples=4)
        if supervised_train_loader is not None:
            check_loader(supervised_train_loader, n_samples=4)
            check_loader(supervised_val_loader, n_samples=4)
        return

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # The same loss objects are passed for the supervised and the unsupervised stream.
    trainer = self_training.MeanTeacherTrainerWithInvertibleAugmentations(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        augmenter=augmenters,
    )
    trainer.fit(n_iterations)
Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.
We support different domain adaptation settings:
- unsupervised domain adaptation: the default mode when 'supervised_train_paths' and 'supervised_val_paths' are not given.
- semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, when 'supervised_train_paths' and 'supervised_val_paths' are given.
Arguments:
- name: The name for the checkpoint to be trained.
- unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
- unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- source_checkpoint: Checkpoint to the initial model trained on the source domain. This is used to initialize the teacher model. If the checkpoint is not given, then both student and teacher model are initialized from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to be given in order to provide training data from the source domain.
- supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
- supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
- confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the uncertainty of network predictions, and only the data with higher certainty than this threshold is used for training.
- raw_key: The key that holds the raw data inside of the hdf5 or similar files.
- label_key: The key that holds the labels inside of the hdf5 files for supervised learning. This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- train_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for training.
- val_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for validation.
- sample_mask_key: The key to the sample mask dataset inside each file.
- unsupervised_sampler: Sampler to accept or reject patches for the unsupervised data stream.
- supervised_sampler: Sampler to accept or reject patches for the supervised data stream. Pass `False` to disable.
- check: Whether to check the training and validation loaders instead of running training.