synapse_net.training.domain_adaptation
import os
import tempfile
from glob import glob
from pathlib import Path
from typing import Callable, Optional, Tuple

import mrcfile
import torch
import torch_em
import torch_em.self_training as self_training
from elf.io import open_file
from sklearn.model_selection import train_test_split

from .semisupervised_training import get_unsupervised_loader
from .supervised_training import (
    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
)
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
from ..inference.util import _Scaler


def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    sampler: Optional[Callable] = None,
    check: bool = False,
) -> None:
    """Run domain adaptation to transfer a network trained on a source domain for a supervised
    segmentation task to perform this task on a different target domain.

    We support different domain adaptation settings:
        - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
          'supervised_val_paths' are not given.
        - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
          when 'supervised_train_paths' and 'supervised_val_paths' are given.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint to the initial model trained on the source domain.
            This is used to initialize the teacher model. It can either be a directory
            (loaded with `torch_em.util.load_model`) or a file (loaded with `torch.load`).
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
            be given in order to provide training data from the source domain.
        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
            This training data is optional. If given, it is used for supervised learning and requires labels.
        supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain.
            This validation data is optional. If given, it is used for supervised learning and requires labels.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the uncertainty of network predictions, and only
            the data with higher certainty than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
        raw_key_supervised: The key that holds the raw data inside of the hdf5 files for supervised learning.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        sampler: Optional sampler that is passed on to the `MeanTeacherTrainer`.
        check: Whether to visualize samples from the data loaders instead of running training.
    """
    assert (supervised_train_paths is None) == (supervised_val_paths is None)
    is_2d, _ = _determine_ndim(patch_shape)

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data,
        # that's why we have the assertion here.
        assert supervised_train_paths is not None
        print("Mean teacher training from scratch (AdaMT)")
        model = get_2d_model(out_channels=2) if is_2d else get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    # Use the learning rate requested by the caller (it used to be hard-coded to 1e-4).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    unsupervised_train_loader = get_unsupervised_loader(
        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
    )
    unsupervised_val_loader = get_unsupervised_loader(
        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
    )

    if supervised_train_paths is not None:
        assert label_key is not None
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    # Only inspect the data instead of training (used by the '--check' CLI flag).
    if check:
        from torch_em.util.debug import check_loader
        check_loader(unsupervised_train_loader, n_samples=4)
        check_loader(unsupervised_val_loader, n_samples=4)
        if supervised_train_loader is not None:
            check_loader(supervised_train_loader, n_samples=4)
            check_loader(supervised_val_loader, n_samples=4)
        return

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=sampler,
    )
    trainer.fit(n_iterations)


# TODO patch shapes for other models
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""


def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction, key=None):
    """Collect the input files and split them into training and validation paths.

    If fewer than 4 files are found, each volume is split along its first axis into a training
    and a validation crop. If `resize_training_data` is set, each volume is rescaled to the voxel
    size expected by `model_name`. In either of these cases the data is resaved as hdf5 files
    (dataset key "data") in `tmp_dir`; otherwise the original files are split directly.

    Args:
        input_folder: Folder that is searched recursively for input files.
        pattern: Glob pattern for selecting the files.
        resize_training_data: Whether to rescale the data; requires mrc files for the voxel size.
        model_name: Name of the source model, used to determine the rescaling factor.
        tmp_dir: Directory for the resaved data.
        val_fraction: Fraction of the data used for validation.
        key: Key of the data in the input files, used when the data is resaved. Defaults to "data".

    Returns:
        The training paths and the validation paths.
    """
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if len(files) == 0:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation.
    # And resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave val crops or resize the training data.
    resave_data = resave_val_crops or resize_training_data
    if not resave_data:
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    input_key = "data" if key is None else key
    train_paths, val_paths = [], []
    for file_path in files:
        stem = Path(file_path).stem
        # Use a context manager so that the input file is closed after reading.
        with open_file(file_path, mode="r") as f:
            data = f[input_key][:]

        if resize_training_data:
            with mrcfile.open(file_path) as mrc:
                voxel_size = mrc.voxel_size
            # MRC voxel sizes are given in angstrom; convert them to nanometer.
            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
            scale = compute_scale_from_voxel_size(voxel_size, model_name)
            scaler = _Scaler(scale, verbose=False)
            data = scaler.scale_input(data)

        if resave_val_crops:
            n_slices = data.shape[0]
            val_slice = int((1.0 - val_fraction) * n_slices)
            train_data, val_data = data[:val_slice], data[val_slice:]

            train_path = os.path.join(tmp_dir, f"{stem}_train.h5")
            with open_file(train_path, mode="w") as f:
                f.create_dataset("data", data=train_data, compression="lzf")
            train_paths.append(train_path)

            val_path = os.path.join(tmp_dir, f"{stem}_val.h5")
            with open_file(val_path, mode="w") as f:
                f.create_dataset("data", data=val_data, compression="lzf")
            val_paths.append(val_path)

        else:
            output_path = os.path.join(tmp_dir, f"{stem}.h5")
            with open_file(output_path, mode="w") as f:
                f.create_dataset("data", data=data, compression="lzf")
            train_paths.append(output_path)

    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths


def _parse_patch_shape(patch_shape, model_name):
    """Return `patch_shape`, or the default patch shape of `model_name` if it is None."""
    if patch_shape is None:
        if model_name not in PATCH_SHAPES:
            raise ValueError(
                f"There is no default patch shape for the model {model_name}. Please pass a patch shape."
            )
        patch_shape = PATCH_SHAPES[model_name]
    return patch_shape


def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
        "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
        "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument).\n"  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained. ")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
        "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's training data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    # Training has to run inside this context: the temporary directory, and any data
    # resaved into it by `_get_paths`, is deleted when the context exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data,
            args.source_model, tmp_dir, args.val_fraction, key=args.key,
        )
        if all(os.path.dirname(path) == tmp_dir for path in unsupervised_train_paths):
            # The data was resaved by `_get_paths`, which always writes it to the "data" dataset.
            raw_key = "data"
        else:
            unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
            check=args.check,
        )
def
mean_teacher_adaptation( name: str, unsupervised_train_paths: Tuple[str], unsupervised_val_paths: Tuple[str], patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, source_checkpoint: Optional[str] = None, supervised_train_paths: Optional[Tuple[str]] = None, supervised_val_paths: Optional[Tuple[str]] = None, confidence_threshold: float = 0.9, raw_key: str = 'raw', raw_key_supervised: str = 'raw', label_key: Optional[str] = None, batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 10000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, sampler: Optional[<built-in function callable>] = None) -> None:
def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    sampler: Optional[callable] = None,
    check: bool = False,
) -> None:
    """Run domain adaptation to transfer a network trained on a source domain for a supervised
    segmentation task to perform this task on a different target domain.

    We support different domain adaptation settings:
        - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
          'supervised_val_paths' are not given.
        - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
          when 'supervised_train_paths' and 'supervised_val_paths' are given.

    Args:
        name: The name for the checkpoint to be trained.
        unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats
            for the training data in the target domain.
            This training data is used for unsupervised learning, so it does not require labels.
        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
            for the validation data in the target domain.
            This validation data is used for unsupervised learning, so it does not require labels.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        source_checkpoint: Checkpoint to the initial model trained on the source domain.
            This is used to initialize the teacher model. It can either be a directory
            (loaded with `torch_em.util.load_model`) or a file (loaded with `torch.load`).
            If the checkpoint is not given, then both student and teacher model are initialized
            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
            be given in order to provide training data from the source domain.
        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
            This training data is optional. If given, it is used for supervised learning and requires labels.
        supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain.
            This validation data is optional. If given, it is used for supervised learning and requires labels.
        confidence_threshold: The threshold for filtering data in the unsupervised loss.
            The label filtering is done based on the uncertainty of network predictions, and only
            the data with higher certainty than this threshold is used for training.
        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
        raw_key_supervised: The key that holds the raw data inside of the hdf5 files for supervised learning.
        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        sampler: Optional sampler that is passed on to the `MeanTeacherTrainer`.
        check: Whether to visualize samples from the data loaders instead of running training.
    """
    assert (supervised_train_paths is None) == (supervised_val_paths is None)
    is_2d, _ = _determine_ndim(patch_shape)

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data,
        # that's why we have the assertion here.
        assert supervised_train_paths is not None
        print("Mean teacher training from scratch (AdaMT)")
        model = get_2d_model(out_channels=2) if is_2d else get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    # Use the learning rate requested by the caller (it used to be hard-coded to 1e-4).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    unsupervised_train_loader = get_unsupervised_loader(
        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
    )
    unsupervised_val_loader = get_unsupervised_loader(
        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
    )

    if supervised_train_paths is not None:
        assert label_key is not None
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    # Only inspect the data instead of training (used by the '--check' CLI flag).
    if check:
        from torch_em.util.debug import check_loader
        check_loader(unsupervised_train_loader, n_samples=4)
        check_loader(unsupervised_val_loader, n_samples=4)
        if supervised_train_loader is not None:
            check_loader(supervised_train_loader, n_samples=4)
            check_loader(supervised_val_loader, n_samples=4)
        return

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=sampler,
    )
    trainer.fit(n_iterations)
Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.
We support different domain adaptation settings:
- unsupervised domain adaptation: the default mode when 'supervised_train_paths' and 'supervised_val_paths' are not given.
- semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, when 'supervised_train_paths' and 'supervised_val_paths' are given.
Arguments:
- name: The name for the checkpoint to be trained.
- unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
- unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- source_checkpoint: Checkpoint to the initial model trained on the source domain.
This is used to initialize the teacher model.
If the checkpoint is not given, then both student and teacher model are initialized
from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to be given in order to provide training data from the source domain.
- supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
- supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
- confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the uncertainty of network predictions, and only the data with higher certainty than this threshold is used for training.
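- raw_key: The key that holds the raw data inside of the hdf5 or similar files.
- raw_key_supervised: The key that holds the raw data inside of the hdf5 files for supervised learning.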
- label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- sampler: Optional sampler that is passed on to the `MeanTeacherTrainer`.