synapse_net.training.domain_adaptation

  1import os
  2import tempfile
  3from glob import glob
  4from pathlib import Path
  5from typing import Optional, Tuple
  6
  7import mrcfile
  8import torch
  9import torch_em
 10import torch_em.self_training as self_training
 11from elf.io import open_file
 12from sklearn.model_selection import train_test_split
 13
 14from .semisupervised_training import get_unsupervised_loader
 15from .supervised_training import (
 16    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
 17)
 18from ..inference.inference import get_model_path, compute_scale_from_voxel_size, get_available_models
 19from ..inference.util import _Scaler
 20
 21# configure weak augmentations
 22from torch_em.transform.invertible_augmentations import DEFAULT_WEAK_AUGMENTATIONS
 23
# TODO - test settings - fixed kernel size for `RandomGaussianBlur`, fixed std for `RandomGaussianNoise`
# NOTE: This replaces the "intensity" entry of the dictionary imported from torch_em in place,
# so the override happens as a side effect of importing this module.
# `(0.0)` and `(0.1)` are plain floats (parentheses without a trailing comma do not create a tuple),
# whereas `kernel_size` and `sigma` are 2-tuples.
DEFAULT_WEAK_AUGMENTATIONS["intensity"] = {
    "RandomGaussianBlur": {"kernel_size": (19, 19), "sigma": (0.1, 3.0)},
    "RandomGaussianNoise": {"mean": (0.0), "std": (0.1)},
}
 29
 30def mean_teacher_adaptation(
 31    name: str,
 32    unsupervised_train_paths: Tuple[str],
 33    unsupervised_val_paths: Tuple[str],
 34    patch_shape: Tuple[int, int, int],
 35    save_root: Optional[str] = None,
 36    source_checkpoint: Optional[str] = None,
 37    supervised_train_paths: Optional[Tuple[str]] = None,
 38    supervised_val_paths: Optional[Tuple[str]] = None,
 39    confidence_threshold: float = 0.9,
 40    raw_key: str = "raw",
 41    raw_key_supervised: str = "raw",
 42    label_key: Optional[str] = None,
 43    batch_size: int = 1,
 44    lr: float = 1e-4,
 45    n_iterations: int = int(1e4),
 46    n_samples_train: Optional[int] = None,
 47    n_samples_val: Optional[int] = None,
 48    train_mask_paths: Optional[Tuple[str]] = None,
 49    val_mask_paths: Optional[Tuple[str]] = None,
 50    sample_mask_key: Optional[str] = None,
 51    unsupervised_sampler: Optional[callable] = None,
 52    supervised_sampler: Optional[callable] = None,
 53    check: bool = False,
 54) -> None:
 55    """Run domain adaptation to transfer a network trained on a source domain for a supervised
 56    segmentation task to perform this task on a different target domain.
 57
 58    We support different domain adaptation settings:
 59    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
 60     'supervised_val_paths' are not given.
 61    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
 62      when 'supervised_train_paths' and 'supervised_val_paths' are given.
 63    
 64    Args:
 65        name: The name for the checkpoint to be trained.
 66        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
 67            for the training data in the target domain.
 68            This training data is used for unsupervised learning, so it does not require labels.
 69        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
 70            for the validation data in the target domain.
 71            This validation data is used for unsupervised learning, so it does not require labels.
 72        patch_shape: The patch shape used for a training example.
 73            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 74            e.g. 'patch_shape = [1, 512, 512]'.
 75        save_root: Folder where the checkpoint will be saved.
 76        source_checkpoint: Checkpoint to the initial model trained on the source domain.
 77            This is used to initialize the teacher model.
 78            If the checkpoint is not given, then both student and teacher model are initialized
 79            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
 80            be given in order to provide training data from the source domain.
 81        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
 82            This training data is optional. If given, it is used for unsupervised learnig and requires labels.
 83        supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
 84            This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
 85        confidence_threshold: The threshold for filtering data in the unsupervised loss.
 86            The label filtering is done based on the uncertainty of network predictions, and only
 87            the data with higher certainty than this threshold is used for training.
 88        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
 89        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
 90            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
 91        batch_size: The batch size for training.
 92        lr: The initial learning rate.
 93        n_iterations: The number of iterations to train for.
 94        n_samples_train: The number of train samples per epoch. By default this will be estimated
 95            based on the patch_shape and size of the volumes used for training.
 96        n_samples_val: The number of val samples per epoch. By default this will be estimated
 97            based on the patch_shape and size of the volumes used for validation.
 98        train_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for training.
 99        val_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for validation.
100        sample_mask_key: The key to the sample mask dataset inside each file.
101        unsupervised_sampler: Sampler to accept or reject patches for the unsupervised data stream.
102        supervised_sampler: Sampler to accept or reject patches for the supervised data stream.
103            Pass `False` to disable.
104        check: Whether to check the training and validation loaders instead of running training.
105    """  # noqa
106    assert (supervised_train_paths is None) == (supervised_val_paths is None)
107    is_2d, _ = _determine_ndim(patch_shape)
108
109    if source_checkpoint is None:
110        # training from scratch only makes sense if we have supervised training data
111        # that's why we have the assertion here.
112        assert supervised_train_paths is not None
113        print("Mean teacher training from scratch (AdaMT)")
114        if is_2d:
115            model = get_2d_model(out_channels=2)
116        else:
117            model = get_3d_model(out_channels=2)
118        reinit_teacher = True
119    else:
120        print("Mean teacher training initialized from source model:", source_checkpoint)
121        if os.path.isdir(source_checkpoint):
122            model = torch_em.util.load_model(source_checkpoint)
123        else:
124            model = torch.load(source_checkpoint, weights_only=False)
125        reinit_teacher = False
126
127    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
128    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
129
130    # self training functionality
131    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
132    loss = self_training.SelfTrainingLossWithInvertibleAugmentations()
133    loss_and_metric = self_training.SelfTrainingLossAndMetricWithInvertibleAugmentations()
134
135    ndim = 2 if is_2d else 3
136    augmenters = torch_em.transform.invertible_augmentations.MeanTeacherAugmenters(ndim=ndim)
137
138    unsupervised_train_loader = get_unsupervised_loader(
139        data_paths=unsupervised_train_paths,
140        raw_key=raw_key,
141        patch_shape=patch_shape,
142        batch_size=batch_size,
143        n_samples=n_samples_train,
144        sample_mask_paths=train_mask_paths,
145        sample_mask_key=sample_mask_key,
146        sampler=unsupervised_sampler,
147    )
148    unsupervised_val_loader = get_unsupervised_loader(
149        data_paths=unsupervised_val_paths,
150        raw_key=raw_key,
151        patch_shape=patch_shape,
152        batch_size=batch_size,
153        n_samples=n_samples_val,
154        sample_mask_paths=val_mask_paths,
155        sample_mask_key=sample_mask_key,
156        sampler=unsupervised_sampler,
157    )
158
159    if supervised_train_paths is not None:
160        assert label_key is not None
161        supervised_train_loader = get_supervised_loader(
162            supervised_train_paths, raw_key_supervised, label_key,
163            patch_shape, batch_size, n_samples=n_samples_train,
164            sampler=supervised_sampler,
165        )
166        supervised_val_loader = get_supervised_loader(
167            supervised_val_paths, raw_key_supervised, label_key,
168            patch_shape, batch_size, n_samples=n_samples_val,
169            sampler=supervised_sampler,
170        )
171    else:
172        supervised_train_loader = None
173        supervised_val_loader = None
174
175    if check:
176        from torch_em.util.debug import check_loader
177        check_loader(unsupervised_train_loader, n_samples=4)
178        check_loader(unsupervised_val_loader, n_samples=4)
179        if supervised_train_loader is not None:
180            check_loader(supervised_train_loader, n_samples=4)
181            check_loader(supervised_val_loader, n_samples=4)
182        return
183
184    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
185    trainer = self_training.MeanTeacherTrainerWithInvertibleAugmentations(
186        name=name,
187        model=model,
188        optimizer=optimizer,
189        lr_scheduler=scheduler,
190        pseudo_labeler=pseudo_labeler,
191        unsupervised_loss=loss,
192        unsupervised_loss_and_metric=loss_and_metric,
193        supervised_train_loader=supervised_train_loader,
194        unsupervised_train_loader=unsupervised_train_loader,
195        supervised_val_loader=supervised_val_loader,
196        unsupervised_val_loader=unsupervised_val_loader,
197        supervised_loss=loss,
198        supervised_loss_and_metric=loss_and_metric,
199        logger=self_training.SelfTrainingTensorboardLogger,
200        mixed_precision=True,
201        log_image_interval=100,
202        compile_model=False,
203        device=device,
204        reinit_teacher=reinit_teacher,
205        save_root=save_root,
206        augmenter=augmenters,
207    )
208    trainer.fit(n_iterations)
209
210
# TODO patch shapes for other models
# Default training patch shapes, keyed by source model name.
# `_parse_patch_shape` falls back to this mapping when no patch shape is passed explicitly.
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""
217
218
219def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
220    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
221    if len(files) == 0:
222        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")
223
224    # Heuristic: if we have less then 4 files then we crop a part of the volumes for validation.
225    # And resave the volumes.
226    resave_val_crops = len(files) < 4
227
228    # We only resave the data if we resave val crops or resize the training data
229    resave_data = resave_val_crops or resize_training_data
230    if not resave_data:
231        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
232        return train_paths, val_paths
233
234    train_paths, val_paths = [], []
235    for file_path in files:
236        file_name = os.path.basename(file_path)
237        data = open_file(file_path, mode="r")["data"][:]
238
239        if resize_training_data:
240            with mrcfile.open(file_path) as f:
241                voxel_size = f.voxel_size
242            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
243            scale = compute_scale_from_voxel_size(voxel_size, model_name)
244            scaler = _Scaler(scale, verbose=False)
245            data = scaler.sale_input(data)
246
247        if resave_val_crops:
248            n_slices = data.shape[0]
249            val_slice = int((1.0 - val_fraction) * n_slices)
250            train_data, val_data = data[:val_slice], data[val_slice:]
251
252            train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
253            with open_file(train_path, mode="w") as f:
254                f.create_dataset("data", data=train_data, compression="lzf")
255            train_paths.append(train_path)
256
257            val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
258            with open_file(val_path, mode="w") as f:
259                f.create_dataset("data", data=val_data, compression="lzf")
260            val_paths.append(val_path)
261
262        else:
263            output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
264            with open_file(output_path, mode="w") as f:
265                f.create_dataset("data", data=data, compression="lzf")
266            train_paths.append(output_path)
267
268    if not resave_val_crops:
269        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)
270
271    return train_paths, val_paths
272
273
274def _parse_patch_shape(patch_shape, model_name):
275    if patch_shape is None:
276        patch_shape = PATCH_SHAPES[model_name]
277    return patch_shape
278
279
def main():
    """@private
    """
    # Command line entry point: parse the arguments, collect (and if necessary resave)
    # the training data and run unsupervised domain adaptation from a source model.
    import argparse

    parser = argparse.ArgumentParser(
        # The description is printed verbatim (RawTextHelpFormatter), so every sentence
        # needs its own explicit separator.
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
        "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
        "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument).\n"  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI.\n"
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained. ")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    available_models = get_available_models()
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
        "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used.\n"
        f"The following source models are available: {available_models}"
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's training data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. If fewer than four files are found, this fraction of each volume is cropped for validation instead of splitting the files.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    # The training has to run inside of the context manager, because resaved training data
    # is written to the temporary directory, which is deleted on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data,
            args.source_model, tmp_dir, args.val_fraction,
        )
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
            check=args.check,
            save_root=args.save_root,
        )
def mean_teacher_adaptation( name: str, unsupervised_train_paths: Tuple[str], unsupervised_val_paths: Tuple[str], patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, source_checkpoint: Optional[str] = None, supervised_train_paths: Optional[Tuple[str]] = None, supervised_val_paths: Optional[Tuple[str]] = None, confidence_threshold: float = 0.9, raw_key: str = 'raw', raw_key_supervised: str = 'raw', label_key: Optional[str] = None, batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 10000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, train_mask_paths: Optional[Tuple[str]] = None, val_mask_paths: Optional[Tuple[str]] = None, sample_mask_key: Optional[str] = None, unsupervised_sampler: Optional[<built-in function callable>] = None, supervised_sampler: Optional[<built-in function callable>] = None, check: bool = False) -> None:
 31def mean_teacher_adaptation(
 32    name: str,
 33    unsupervised_train_paths: Tuple[str],
 34    unsupervised_val_paths: Tuple[str],
 35    patch_shape: Tuple[int, int, int],
 36    save_root: Optional[str] = None,
 37    source_checkpoint: Optional[str] = None,
 38    supervised_train_paths: Optional[Tuple[str]] = None,
 39    supervised_val_paths: Optional[Tuple[str]] = None,
 40    confidence_threshold: float = 0.9,
 41    raw_key: str = "raw",
 42    raw_key_supervised: str = "raw",
 43    label_key: Optional[str] = None,
 44    batch_size: int = 1,
 45    lr: float = 1e-4,
 46    n_iterations: int = int(1e4),
 47    n_samples_train: Optional[int] = None,
 48    n_samples_val: Optional[int] = None,
 49    train_mask_paths: Optional[Tuple[str]] = None,
 50    val_mask_paths: Optional[Tuple[str]] = None,
 51    sample_mask_key: Optional[str] = None,
 52    unsupervised_sampler: Optional[callable] = None,
 53    supervised_sampler: Optional[callable] = None,
 54    check: bool = False,
 55) -> None:
 56    """Run domain adaptation to transfer a network trained on a source domain for a supervised
 57    segmentation task to perform this task on a different target domain.
 58
 59    We support different domain adaptation settings:
 60    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
 61     'supervised_val_paths' are not given.
 62    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
 63      when 'supervised_train_paths' and 'supervised_val_paths' are given.
 64    
 65    Args:
 66        name: The name for the checkpoint to be trained.
 67        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
 68            for the training data in the target domain.
 69            This training data is used for unsupervised learning, so it does not require labels.
 70        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
 71            for the validation data in the target domain.
 72            This validation data is used for unsupervised learning, so it does not require labels.
 73        patch_shape: The patch shape used for a training example.
 74            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 75            e.g. 'patch_shape = [1, 512, 512]'.
 76        save_root: Folder where the checkpoint will be saved.
 77        source_checkpoint: Checkpoint to the initial model trained on the source domain.
 78            This is used to initialize the teacher model.
 79            If the checkpoint is not given, then both student and teacher model are initialized
 80            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
 81            be given in order to provide training data from the source domain.
 82        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
 83            This training data is optional. If given, it is used for unsupervised learnig and requires labels.
 84        supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
 85            This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
 86        confidence_threshold: The threshold for filtering data in the unsupervised loss.
 87            The label filtering is done based on the uncertainty of network predictions, and only
 88            the data with higher certainty than this threshold is used for training.
 89        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
 90        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
 91            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
 92        batch_size: The batch size for training.
 93        lr: The initial learning rate.
 94        n_iterations: The number of iterations to train for.
 95        n_samples_train: The number of train samples per epoch. By default this will be estimated
 96            based on the patch_shape and size of the volumes used for training.
 97        n_samples_val: The number of val samples per epoch. By default this will be estimated
 98            based on the patch_shape and size of the volumes used for validation.
 99        train_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for training.
100        val_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for validation.
101        sample_mask_key: The key to the sample mask dataset inside each file.
102        unsupervised_sampler: Sampler to accept or reject patches for the unsupervised data stream.
103        supervised_sampler: Sampler to accept or reject patches for the supervised data stream.
104            Pass `False` to disable.
105        check: Whether to check the training and validation loaders instead of running training.
106    """  # noqa
107    assert (supervised_train_paths is None) == (supervised_val_paths is None)
108    is_2d, _ = _determine_ndim(patch_shape)
109
110    if source_checkpoint is None:
111        # training from scratch only makes sense if we have supervised training data
112        # that's why we have the assertion here.
113        assert supervised_train_paths is not None
114        print("Mean teacher training from scratch (AdaMT)")
115        if is_2d:
116            model = get_2d_model(out_channels=2)
117        else:
118            model = get_3d_model(out_channels=2)
119        reinit_teacher = True
120    else:
121        print("Mean teacher training initialized from source model:", source_checkpoint)
122        if os.path.isdir(source_checkpoint):
123            model = torch_em.util.load_model(source_checkpoint)
124        else:
125            model = torch.load(source_checkpoint, weights_only=False)
126        reinit_teacher = False
127
128    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
129    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
130
131    # self training functionality
132    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
133    loss = self_training.SelfTrainingLossWithInvertibleAugmentations()
134    loss_and_metric = self_training.SelfTrainingLossAndMetricWithInvertibleAugmentations()
135
136    ndim = 2 if is_2d else 3
137    augmenters = torch_em.transform.invertible_augmentations.MeanTeacherAugmenters(ndim=ndim)
138
139    unsupervised_train_loader = get_unsupervised_loader(
140        data_paths=unsupervised_train_paths,
141        raw_key=raw_key,
142        patch_shape=patch_shape,
143        batch_size=batch_size,
144        n_samples=n_samples_train,
145        sample_mask_paths=train_mask_paths,
146        sample_mask_key=sample_mask_key,
147        sampler=unsupervised_sampler,
148    )
149    unsupervised_val_loader = get_unsupervised_loader(
150        data_paths=unsupervised_val_paths,
151        raw_key=raw_key,
152        patch_shape=patch_shape,
153        batch_size=batch_size,
154        n_samples=n_samples_val,
155        sample_mask_paths=val_mask_paths,
156        sample_mask_key=sample_mask_key,
157        sampler=unsupervised_sampler,
158    )
159
160    if supervised_train_paths is not None:
161        assert label_key is not None
162        supervised_train_loader = get_supervised_loader(
163            supervised_train_paths, raw_key_supervised, label_key,
164            patch_shape, batch_size, n_samples=n_samples_train,
165            sampler=supervised_sampler,
166        )
167        supervised_val_loader = get_supervised_loader(
168            supervised_val_paths, raw_key_supervised, label_key,
169            patch_shape, batch_size, n_samples=n_samples_val,
170            sampler=supervised_sampler,
171        )
172    else:
173        supervised_train_loader = None
174        supervised_val_loader = None
175
176    if check:
177        from torch_em.util.debug import check_loader
178        check_loader(unsupervised_train_loader, n_samples=4)
179        check_loader(unsupervised_val_loader, n_samples=4)
180        if supervised_train_loader is not None:
181            check_loader(supervised_train_loader, n_samples=4)
182            check_loader(supervised_val_loader, n_samples=4)
183        return
184
185    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
186    trainer = self_training.MeanTeacherTrainerWithInvertibleAugmentations(
187        name=name,
188        model=model,
189        optimizer=optimizer,
190        lr_scheduler=scheduler,
191        pseudo_labeler=pseudo_labeler,
192        unsupervised_loss=loss,
193        unsupervised_loss_and_metric=loss_and_metric,
194        supervised_train_loader=supervised_train_loader,
195        unsupervised_train_loader=unsupervised_train_loader,
196        supervised_val_loader=supervised_val_loader,
197        unsupervised_val_loader=unsupervised_val_loader,
198        supervised_loss=loss,
199        supervised_loss_and_metric=loss_and_metric,
200        logger=self_training.SelfTrainingTensorboardLogger,
201        mixed_precision=True,
202        log_image_interval=100,
203        compile_model=False,
204        device=device,
205        reinit_teacher=reinit_teacher,
206        save_root=save_root,
207        augmenter=augmenters,
208    )
209    trainer.fit(n_iterations)

Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.

We support different domain adaptation settings:

  • unsupervised domain adaptation: the default mode when 'supervised_train_paths' and 'supervised_val_paths' are not given.
  • semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, when 'supervised_train_paths' and 'supervised_val_paths' are given.
Arguments:
  • name: The name for the checkpoint to be trained.
  • unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
  • unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • source_checkpoint: Checkpoint to the initial model trained on the source domain. This is used to initialize the teacher model. If the checkpoint is not given, then both student and teacher model are initialized from scratch. In this case supervised_train_paths and supervised_val_paths have to be given in order to provide training data from the source domain.
  • supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
  • supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
  • confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the uncertainty of network predictions, and only the data with higher certainty than this threshold is used for training.
  • raw_key: The key that holds the raw data inside of the hdf5 or similar files.
  • label_key: The key that holds the labels inside of the hdf5 files for supervised learning. This is only required if supervised_train_paths and supervised_val_paths are given.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • train_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for training.
  • val_mask_paths: Sample masks used by the unsupervised sampler to accept or reject patches for validation.
  • sample_mask_key: The key to the sample mask dataset inside each file.
  • unsupervised_sampler: Sampler to accept or reject patches for the unsupervised data stream.
  • supervised_sampler: Sampler to accept or reject patches for the supervised data stream. Pass False to disable.
  • check: Whether to check the training and validation loaders instead of running training.