synapse_net.training.domain_adaptation

import os
import tempfile
from glob import glob
from pathlib import Path
from typing import Optional, Tuple

import mrcfile
import torch
import torch_em
import torch_em.self_training as self_training
from elf.io import open_file
from sklearn.model_selection import train_test_split

from .semisupervised_training import get_unsupervised_loader
from .supervised_training import (
    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
)
from ..inference.inference import get_model_path, compute_scale_from_voxel_size, get_available_models
from ..inference.util import _Scaler


def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    train_mask_paths: Optional[Tuple[str]] = None,
    val_mask_paths: Optional[Tuple[str]] = None,
    patch_sampler: Optional[callable] = None,
    pseudo_label_sampler: Optional[callable] = None,
    device: int = 0,
    check: bool = False,
) -> None:
 47    """Run domain adaptation to transfer a network trained on a source domain for a supervised
 48    segmentation task to perform this task on a different target domain.
 49
 50    We support different domain adaptation settings:
 51    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
 52     'supervised_val_paths' are not given.
 53    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
 54      when 'supervised_train_paths' and 'supervised_val_paths' are given.
 55
 56    Args:
 57        name: The name for the checkpoint to be trained.
 58        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
 59            for the training data in the target domain.
 60            This training data is used for unsupervised learning, so it does not require labels.
 61        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
 62            for the validation data in the target domain.
 63            This validation data is used for unsupervised learning, so it does not require labels.
 64        patch_shape: The patch shape used for a training example.
 65            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 66            e.g. 'patch_shape = [1, 512, 512]'.
 67        save_root: Folder where the checkpoint will be saved.
 68        source_checkpoint: Checkpoint to the initial model trained on the source domain.
 69            This is used to initialize the teacher model.
 70            If the checkpoint is not given, then both student and teacher model are initialized
 71            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
 72            be given in order to provide training data from the source domain.
 73        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
 74            This training data is optional. If given, it is used for unsupervised learnig and requires labels.
 75        supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
 76            This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
 77        confidence_threshold: The threshold for filtering data in the unsupervised loss.
 78            The label filtering is done based on the uncertainty of network predictions, and only
 79            the data with higher certainty than this threshold is used for training.
 80        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
 81        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
 82            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
 83        batch_size: The batch size for training.
 84        lr: The initial learning rate.
 85        n_iterations: The number of iterations to train for.
 86        n_samples_train: The number of train samples per epoch. By default this will be estimated
 87            based on the patch_shape and size of the volumes used for training.
 88        n_samples_val: The number of val samples per epoch. By default this will be estimated
 89            based on the patch_shape and size of the volumes used for validation.
 90        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
 91        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
 92        patch_sampler: Accept or reject patches based on a condition.
 93        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
 94        device: GPU ID for training.
 95        check: Whether to check the training and validation loaders instead of running training.
 96    """  # noqa
    assert (supervised_train_paths is None) == (supervised_val_paths is None)
    is_2d, _ = _determine_ndim(patch_shape)

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data;
        # that's why we have the assertion here.
        assert supervised_train_paths is not None
        print("Mean teacher training from scratch (AdaMT)")
        if is_2d:
            model = get_2d_model(out_channels=2)
        else:
            model = get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    unsupervised_train_loader = get_unsupervised_loader(
        data_paths=unsupervised_train_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_train,
        sample_mask_paths=train_mask_paths,
        sampler=patch_sampler,
    )
    unsupervised_val_loader = get_unsupervised_loader(
        data_paths=unsupervised_val_paths,
        raw_key=raw_key,
        patch_shape=patch_shape,
        batch_size=batch_size,
        n_samples=n_samples_val,
        sample_mask_paths=val_mask_paths,
        sampler=patch_sampler,
    )

    if supervised_train_paths is not None:
        assert label_key is not None
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    if check:
        from torch_em.util.debug import check_loader
        check_loader(unsupervised_train_loader, n_samples=4)
        check_loader(unsupervised_val_loader, n_samples=4)
        if supervised_train_loader is not None:
            check_loader(supervised_train_loader, n_samples=4)
            check_loader(supervised_val_loader, n_samples=4)
        return

    device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=pseudo_label_sampler,
    )
    trainer.fit(n_iterations)


# TODO patch shapes for other models
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""

def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if len(files) == 0:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have fewer than 4 files, then we crop a part of each volume for validation
    # and resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave the val crops or resize the training data.
    resave_data = resave_val_crops or resize_training_data
    if not resave_data:
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    train_paths, val_paths = [], []
    for file_path in files:
        file_name = os.path.basename(file_path)
        data = open_file(file_path, mode="r")["data"][:]

        if resize_training_data:
            # Read the voxel size from the mrc header (in angstrom) and convert it to nanometer.
            with mrcfile.open(file_path) as f:
                voxel_size = f.voxel_size
            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
            scale = compute_scale_from_voxel_size(voxel_size, model_name)
            scaler = _Scaler(scale, verbose=False)
            data = scaler.scale_input(data)

        if resave_val_crops:
            # Crop the last 'val_fraction' of the slices for validation.
            n_slices = data.shape[0]
            val_slice = int((1.0 - val_fraction) * n_slices)
            train_data, val_data = data[:val_slice], data[val_slice:]

            train_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_train.h5")
            with open_file(train_path, mode="w") as f:
                f.create_dataset("data", data=train_data, compression="lzf")
            train_paths.append(train_path)

            val_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5")).replace(".h5", "_val.h5")
            with open_file(val_path, mode="w") as f:
                f.create_dataset("data", data=val_data, compression="lzf")
            val_paths.append(val_path)

        else:
            output_path = os.path.join(tmp_dir, Path(file_name).with_suffix(".h5"))
            with open_file(output_path, mode="w") as f:
                f.create_dataset("data", data=data, compression="lzf")
            train_paths.append(output_path)

    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths


def _parse_patch_shape(patch_shape, model_name):
    if patch_shape is None:
        patch_shape = PATCH_SHAPES[model_name]
    return patch_shape


def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
        "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
        "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument). "  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained.")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    available_models = get_available_models()
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
        "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used.\n"
        f"The following source models are available: {available_models}"
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's training data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given, this will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given, this will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data,
            args.source_model, tmp_dir, args.val_fraction,
        )
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
            check=args.check,
            save_root=args.save_root,
        )
def mean_teacher_adaptation(name: str, unsupervised_train_paths: Tuple[str], unsupervised_val_paths: Tuple[str], patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, source_checkpoint: Optional[str] = None, supervised_train_paths: Optional[Tuple[str]] = None, supervised_val_paths: Optional[Tuple[str]] = None, confidence_threshold: float = 0.9, raw_key: str = 'raw', raw_key_supervised: str = 'raw', label_key: Optional[str] = None, batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 10000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, train_mask_paths: Optional[Tuple[str]] = None, val_mask_paths: Optional[Tuple[str]] = None, patch_sampler: Optional[callable] = None, pseudo_label_sampler: Optional[callable] = None, device: int = 0, check: bool = False) -> None:

Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.

We support different domain adaptation settings:

  • unsupervised domain adaptation: the default mode, used when 'supervised_train_paths' and 'supervised_val_paths' are not given.
  • semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, used when 'supervised_train_paths' and 'supervised_val_paths' are given. Both settings are illustrated in the sketch after this list.
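
A minimal sketch of both settings; all file, key, and checkpoint names below are hypothetical placeholders for illustration:

  from synapse_net.training.domain_adaptation import mean_teacher_adaptation

  # Unsupervised domain adaptation: only unlabeled target data is needed and
  # the teacher is initialized from a checkpoint trained on the source domain.
  mean_teacher_adaptation(
      name="adapted_model",
      unsupervised_train_paths=("target_train.h5",),
      unsupervised_val_paths=("target_val.h5",),
      patch_shape=(48, 256, 256),
      source_checkpoint="checkpoints/vesicles_3d",
  )

  # Semi-supervised domain adaptation: labeled data from the source domain is
  # passed in addition; in this setting the source checkpoint can also be
  # omitted to train both student and teacher from scratch.
  mean_teacher_adaptation(
      name="adapted_model_semi",
      unsupervised_train_paths=("target_train.h5",),
      unsupervised_val_paths=("target_val.h5",),
      patch_shape=(48, 256, 256),
      supervised_train_paths=("source_train.h5",),
      supervised_val_paths=("source_val.h5",),
      label_key="labels",
  )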
Arguments:
  • name: The name for the checkpoint to be trained.
  • unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
  • unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • source_checkpoint: Checkpoint of the initial model trained on the source domain. This is used to initialize the teacher model. If the checkpoint is not given, then both student and teacher model are initialized from scratch. In this case supervised_train_paths and supervised_val_paths have to be given in order to provide training data from the source domain.
  • supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
  • supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
  • confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the uncertainty of network predictions, and only the data with higher certainty than this threshold is used for training.
  • raw_key: The key that holds the raw data inside of the hdf5 or similar files.
  • raw_key_supervised: The key that holds the raw data inside of the hdf5 files used for supervised learning.
  • label_key: The key that holds the labels inside of the hdf5 files for supervised learning. This is only required if supervised_train_paths and supervised_val_paths are given.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
  • val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
  • patch_sampler: Accept or reject patches based on a condition.
  • pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
  • device: GPU ID for training.
  • check: Whether to check the training and validation loaders instead of running training.
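
For illustration, here is a sketch of passing a custom patch sampler, under the assumption that the sampler follows the torch_em convention of a callable that receives the raw patch and returns whether to accept it; the rejection rule, threshold, and paths are hypothetical:

  from synapse_net.training.domain_adaptation import mean_teacher_adaptation

  def reject_empty_patches(raw):
      # Hypothetical acceptance rule: reject nearly uniform patches,
      # e.g. empty regions outside of the specimen.
      return raw.std() > 0.05

  mean_teacher_adaptation(
      name="adapted_model",
      unsupervised_train_paths=("target_train.h5",),
      unsupervised_val_paths=("target_val.h5",),
      patch_shape=(48, 256, 256),
      source_checkpoint="checkpoints/vesicles_3d",
      patch_sampler=reject_empty_patches,
  )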