synapse_net.training.domain_adaptation

  1import os
  2import tempfile
  3from glob import glob
  4from pathlib import Path
  5from typing import Optional, Tuple
  6
  7import mrcfile
  8import torch
  9import torch_em
 10import torch_em.self_training as self_training
 11from elf.io import open_file
 12from sklearn.model_selection import train_test_split
 13
 14from .semisupervised_training import get_unsupervised_loader
 15from .supervised_training import (
 16    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
 17)
 18from ..inference.inference import get_model_path, compute_scale_from_voxel_size
 19from ..inference.util import _Scaler
 20
 21def mean_teacher_adaptation(
 22    name: str,
 23    unsupervised_train_paths: Tuple[str],
 24    unsupervised_val_paths: Tuple[str],
 25    patch_shape: Tuple[int, int, int],
 26    save_root: Optional[str] = None,
 27    source_checkpoint: Optional[str] = None,
 28    supervised_train_paths: Optional[Tuple[str]] = None,
 29    supervised_val_paths: Optional[Tuple[str]] = None,
 30    confidence_threshold: float = 0.9,
 31    raw_key: str = "raw",
 32    raw_key_supervised: str = "raw",
 33    label_key: Optional[str] = None,
 34    batch_size: int = 1,
 35    lr: float = 1e-4,
 36    n_iterations: int = int(1e4),
 37    n_samples_train: Optional[int] = None,
 38    n_samples_val: Optional[int] = None,
 39    train_mask_paths: Optional[Tuple[str]] = None,
 40    val_mask_paths: Optional[Tuple[str]] = None,
 41    patch_sampler: Optional[callable] = None,
 42    pseudo_label_sampler: Optional[callable] = None,
 43    device: int = 0,
 44) -> None:
 45    """Run domain adaptation to transfer a network trained on a source domain for a supervised
 46    segmentation task to perform this task on a different target domain.
 47
 48    We support different domain adaptation settings:
 49    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
 50     'supervised_val_paths' are not given.
 51    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
 52      when 'supervised_train_paths' and 'supervised_val_paths' are given.
 53
 54    Args:
 55        name: The name for the checkpoint to be trained.
 56        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
 57            for the training data in the target domain.
 58            This training data is used for unsupervised learning, so it does not require labels.
 59        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
 60            for the validation data in the target domain.
 61            This validation data is used for unsupervised learning, so it does not require labels.
 62        patch_shape: The patch shape used for a training example.
 63            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 64            e.g. 'patch_shape = [1, 512, 512]'.
 65        save_root: Folder where the checkpoint will be saved.
 66        source_checkpoint: Checkpoint to the initial model trained on the source domain.
 67            This is used to initialize the teacher model.
 68            If the checkpoint is not given, then both student and teacher model are initialized
 69            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
 70            be given in order to provide training data from the source domain.
 71        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
 72            This training data is optional. If given, it is used for unsupervised learnig and requires labels.
 73        supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
 74            This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
 75        confidence_threshold: The threshold for filtering data in the unsupervised loss.
 76            The label filtering is done based on the uncertainty of network predictions, and only
 77            the data with higher certainty than this threshold is used for training.
 78        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
 79        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
 80            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
 81        batch_size: The batch size for training.
 82        lr: The initial learning rate.
 83        n_iterations: The number of iterations to train for.
 84        n_samples_train: The number of train samples per epoch. By default this will be estimated
 85            based on the patch_shape and size of the volumes used for training.
 86        n_samples_val: The number of val samples per epoch. By default this will be estimated
 87            based on the patch_shape and size of the volumes used for validation.
 88        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training. 
 89        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation. 
 90        patch_sampler: Accept or reject patches based on a condition.
 91        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients. 
 92        device: GPU ID for training. 
 93    """
 94    assert (supervised_train_paths is None) == (supervised_val_paths is None)
 95    is_2d, _ = _determine_ndim(patch_shape)
 96
 97    if source_checkpoint is None:
 98        # training from scratch only makes sense if we have supervised training data
 99        # that's why we have the assertion here.
100        assert supervised_train_paths is not None
101        print("Mean teacher training from scratch (AdaMT)")
102        if is_2d:
103            model = get_2d_model(out_channels=2)
104        else:
105            model = get_3d_model(out_channels=2)
106        reinit_teacher = True
107    else:
108        print("Mean teacher training initialized from source model:", source_checkpoint)
109        if os.path.isdir(source_checkpoint):
110            model = torch_em.util.load_model(source_checkpoint)
111        else:
112            model = torch.load(source_checkpoint, weights_only=False)
113        reinit_teacher = False
114
115    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
116    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
117
118    # self training functionality
119    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
120    loss = self_training.DefaultSelfTrainingLoss()
121    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
122    
123    unsupervised_train_loader = get_unsupervised_loader(
124        data_paths=unsupervised_train_paths, 
125        raw_key=raw_key, 
126        patch_shape=patch_shape, 
127        batch_size=batch_size, 
128        n_samples=n_samples_train, 
129        sample_mask_paths=train_mask_paths, 
130        sampler=patch_sampler
131    )
132    unsupervised_val_loader = get_unsupervised_loader(
133        data_paths=unsupervised_val_paths, 
134        raw_key=raw_key, 
135        patch_shape=patch_shape, 
136        batch_size=batch_size, 
137        n_samples=n_samples_val, 
138        sample_mask_paths=val_mask_paths, 
139        sampler=patch_sampler
140    )
141
142    if supervised_train_paths is not None:
143        assert label_key is not None
144        supervised_train_loader = get_supervised_loader(
145            supervised_train_paths, raw_key_supervised, label_key,
146            patch_shape, batch_size, n_samples=n_samples_train,
147        )
148        supervised_val_loader = get_supervised_loader(
149            supervised_val_paths, raw_key_supervised, label_key,
150            patch_shape, batch_size, n_samples=n_samples_val,
151        )
152    else:
153        supervised_train_loader = None
154        supervised_val_loader = None
155
156    device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
157    trainer = self_training.MeanTeacherTrainer(
158        name=name,
159        model=model,
160        optimizer=optimizer,
161        lr_scheduler=scheduler,
162        pseudo_labeler=pseudo_labeler,
163        unsupervised_loss=loss,
164        unsupervised_loss_and_metric=loss_and_metric,
165        supervised_train_loader=supervised_train_loader,
166        unsupervised_train_loader=unsupervised_train_loader,
167        supervised_val_loader=supervised_val_loader,
168        unsupervised_val_loader=unsupervised_val_loader,
169        supervised_loss=loss,
170        supervised_loss_and_metric=loss_and_metric,
171        logger=self_training.SelfTrainingTensorboardLogger,
172        mixed_precision=True,
173        log_image_interval=100,
174        compile_model=False,
175        device=device,
176        reinit_teacher=reinit_teacher,
177        save_root=save_root,
178        sampler=pseudo_label_sampler,
179    )
180    trainer.fit(n_iterations)
181    
182    
# Default training patch shapes per source model, used when no patch shape is passed.
# TODO patch shapes for other models
PATCH_SHAPES = dict(
    vesicles_3d=[48, 256, 256],
)
"""@private
"""
189
190
def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
    """Find the input files and split them into training and validation paths.

    If fewer than 4 files are found, or if the data has to be resized to the voxel size of the
    source model, the volumes are loaded and re-saved as hdf5 files (dataset "data") in `tmp_dir`.
    With fewer than 4 files the last `val_fraction` of the slices along the first axis of each
    volume is saved as a separate validation file; otherwise the files are split with `train_test_split`.

    Args:
        input_folder: The folder that is searched recursively for input files.
        pattern: The glob pattern for selecting files, e.g. '*.mrc'.
        resize_training_data: Whether to rescale the data to the voxel size of the source model.
            The voxel size is read from the mrc header, so this requires mrc input files.
        model_name: The name of the source model, used to compute the rescaling factor.
        tmp_dir: The folder where re-saved data is written.
        val_fraction: The fraction of the data used for validation.

    Returns:
        The training paths and the validation paths.

    Raises:
        ValueError: If no files match `pattern` in `input_folder`.
    """
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if len(files) == 0:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have less than 4 files then we crop a part of the volumes for validation
    # and resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave val crops or resize the training data.
    if not (resave_val_crops or resize_training_data):
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    train_paths, val_paths = [], []
    for file_id, file_path in enumerate(files):
        # The file id keeps the output names unique if files in different sub-folders
        # of the (recursively searched) input folder have the same name.
        stem = f"{Path(file_path).stem}_{file_id}"

        # All processing happens while the input file is open, so that the data stays valid
        # even if the file backend returns a view into the file.
        with open_file(file_path, mode="r") as f_in:
            data = f_in["data"][:]

            if resize_training_data:
                # Only the header is needed to read the voxel size.
                with mrcfile.open(file_path, header_only=True) as mrc:
                    voxel_size = mrc.voxel_size
                # The division by 10 presumably converts from angstrom (mrc convention) to nanometer.
                voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
                scale = compute_scale_from_voxel_size(voxel_size, model_name)
                scaler = _Scaler(scale, verbose=False)
                data = scaler.scale_input(data)

            if resave_val_crops:
                val_slice = int((1.0 - val_fraction) * data.shape[0])
                outputs = [
                    (f"{stem}_train.h5", data[:val_slice], train_paths),
                    (f"{stem}_val.h5", data[val_slice:], val_paths),
                ]
            else:
                outputs = [(f"{stem}.h5", data, train_paths)]

            for out_name, out_data, out_paths in outputs:
                out_path = os.path.join(tmp_dir, out_name)
                with open_file(out_path, mode="w") as f_out:
                    f_out.create_dataset("data", data=out_data, compression="lzf")
                out_paths.append(out_path)

    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths
244
245
def _parse_patch_shape(patch_shape, model_name):
    """Return the patch shape for training, falling back to the default of the source model.

    Args:
        patch_shape: The patch shape passed by the user, or None to use the default.
        model_name: The name of the source model, used to look up the default in `PATCH_SHAPES`.

    Returns:
        The patch shape to use for training.

    Raises:
        KeyError: If no patch shape is given and there is no default for `model_name`.
    """
    if patch_shape is not None:
        return patch_shape
    try:
        default_shape = PATCH_SHAPES[model_name]
    except KeyError:
        raise KeyError(
            f"No default patch shape is known for the model '{model_name}', please pass a patch shape. "
            f"Models with a default patch shape: {sorted(PATCH_SHAPES)}."
        ) from None
    # Return a copy so that callers cannot accidentally modify the module-level defaults.
    return list(default_shape)
250
def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
        "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
        "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument).\n"  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained. ")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
        "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's training data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. If fewer than 4 files are found, this fraction of the slices of each volume is used for validation instead.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The argparse destinations are 'input_folder' and 'file_pattern'.
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data,
            args.source_model, tmp_dir, args.val_fraction,
        )

        # Data re-saved by `_get_paths` is written to `tmp_dir` under the dataset name "data",
        # so that is the only valid key for it.
        was_resaved = all(Path(path).parent == Path(tmp_dir) for path in unsupervised_train_paths)
        key = "data" if was_resaved else args.key
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, key)
        # Treat the validation paths the same way as the training paths.
        unsupervised_val_paths, _ = _derive_key_from_files(unsupervised_val_paths, key)

        if args.check:
            # Visualize a few samples from the data loaders instead of running training.
            from torch_em.util.debug import check_loader

            for paths, n_samples in (
                (unsupervised_train_paths, args.n_samples_train),
                (unsupervised_val_paths, args.n_samples_val),
            ):
                loader = get_unsupervised_loader(
                    data_paths=paths,
                    raw_key=raw_key,
                    patch_shape=patch_shape,
                    batch_size=args.batch_size,
                    n_samples=n_samples,
                    sample_mask_paths=None,
                    sampler=None,
                )
                check_loader(loader, n_samples=4)
            return

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
        )
def mean_teacher_adaptation( name: str, unsupervised_train_paths: Tuple[str], unsupervised_val_paths: Tuple[str], patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, source_checkpoint: Optional[str] = None, supervised_train_paths: Optional[Tuple[str]] = None, supervised_val_paths: Optional[Tuple[str]] = None, confidence_threshold: float = 0.9, raw_key: str = 'raw', raw_key_supervised: str = 'raw', label_key: Optional[str] = None, batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 10000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, train_mask_paths: Optional[Tuple[str]] = None, val_mask_paths: Optional[Tuple[str]] = None, patch_sampler: Optional[<built-in function callable>] = None, pseudo_label_sampler: Optional[<built-in function callable>] = None, device: int = 0) -> None:
 22def mean_teacher_adaptation(
 23    name: str,
 24    unsupervised_train_paths: Tuple[str],
 25    unsupervised_val_paths: Tuple[str],
 26    patch_shape: Tuple[int, int, int],
 27    save_root: Optional[str] = None,
 28    source_checkpoint: Optional[str] = None,
 29    supervised_train_paths: Optional[Tuple[str]] = None,
 30    supervised_val_paths: Optional[Tuple[str]] = None,
 31    confidence_threshold: float = 0.9,
 32    raw_key: str = "raw",
 33    raw_key_supervised: str = "raw",
 34    label_key: Optional[str] = None,
 35    batch_size: int = 1,
 36    lr: float = 1e-4,
 37    n_iterations: int = int(1e4),
 38    n_samples_train: Optional[int] = None,
 39    n_samples_val: Optional[int] = None,
 40    train_mask_paths: Optional[Tuple[str]] = None,
 41    val_mask_paths: Optional[Tuple[str]] = None,
 42    patch_sampler: Optional[callable] = None,
 43    pseudo_label_sampler: Optional[callable] = None,
 44    device: int = 0,
 45) -> None:
 46    """Run domain adaptation to transfer a network trained on a source domain for a supervised
 47    segmentation task to perform this task on a different target domain.
 48
 49    We support different domain adaptation settings:
 50    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
 51     'supervised_val_paths' are not given.
 52    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
 53      when 'supervised_train_paths' and 'supervised_val_paths' are given.
 54
 55    Args:
 56        name: The name for the checkpoint to be trained.
 57        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
 58            for the training data in the target domain.
 59            This training data is used for unsupervised learning, so it does not require labels.
 60        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
 61            for the validation data in the target domain.
 62            This validation data is used for unsupervised learning, so it does not require labels.
 63        patch_shape: The patch shape used for a training example.
 64            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 65            e.g. 'patch_shape = [1, 512, 512]'.
 66        save_root: Folder where the checkpoint will be saved.
 67        source_checkpoint: Checkpoint to the initial model trained on the source domain.
 68            This is used to initialize the teacher model.
 69            If the checkpoint is not given, then both student and teacher model are initialized
 70            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
 71            be given in order to provide training data from the source domain.
 72        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
 73            This training data is optional. If given, it is used for unsupervised learnig and requires labels.
 74        supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
 75            This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
 76        confidence_threshold: The threshold for filtering data in the unsupervised loss.
 77            The label filtering is done based on the uncertainty of network predictions, and only
 78            the data with higher certainty than this threshold is used for training.
 79        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
 80        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
 81            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
 82        batch_size: The batch size for training.
 83        lr: The initial learning rate.
 84        n_iterations: The number of iterations to train for.
 85        n_samples_train: The number of train samples per epoch. By default this will be estimated
 86            based on the patch_shape and size of the volumes used for training.
 87        n_samples_val: The number of val samples per epoch. By default this will be estimated
 88            based on the patch_shape and size of the volumes used for validation.
 89        train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training. 
 90        val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation. 
 91        patch_sampler: Accept or reject patches based on a condition.
 92        pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients. 
 93        device: GPU ID for training. 
 94    """
 95    assert (supervised_train_paths is None) == (supervised_val_paths is None)
 96    is_2d, _ = _determine_ndim(patch_shape)
 97
 98    if source_checkpoint is None:
 99        # training from scratch only makes sense if we have supervised training data
100        # that's why we have the assertion here.
101        assert supervised_train_paths is not None
102        print("Mean teacher training from scratch (AdaMT)")
103        if is_2d:
104            model = get_2d_model(out_channels=2)
105        else:
106            model = get_3d_model(out_channels=2)
107        reinit_teacher = True
108    else:
109        print("Mean teacher training initialized from source model:", source_checkpoint)
110        if os.path.isdir(source_checkpoint):
111            model = torch_em.util.load_model(source_checkpoint)
112        else:
113            model = torch.load(source_checkpoint, weights_only=False)
114        reinit_teacher = False
115
116    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
117    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)
118
119    # self training functionality
120    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
121    loss = self_training.DefaultSelfTrainingLoss()
122    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()
123    
124    unsupervised_train_loader = get_unsupervised_loader(
125        data_paths=unsupervised_train_paths, 
126        raw_key=raw_key, 
127        patch_shape=patch_shape, 
128        batch_size=batch_size, 
129        n_samples=n_samples_train, 
130        sample_mask_paths=train_mask_paths, 
131        sampler=patch_sampler
132    )
133    unsupervised_val_loader = get_unsupervised_loader(
134        data_paths=unsupervised_val_paths, 
135        raw_key=raw_key, 
136        patch_shape=patch_shape, 
137        batch_size=batch_size, 
138        n_samples=n_samples_val, 
139        sample_mask_paths=val_mask_paths, 
140        sampler=patch_sampler
141    )
142
143    if supervised_train_paths is not None:
144        assert label_key is not None
145        supervised_train_loader = get_supervised_loader(
146            supervised_train_paths, raw_key_supervised, label_key,
147            patch_shape, batch_size, n_samples=n_samples_train,
148        )
149        supervised_val_loader = get_supervised_loader(
150            supervised_val_paths, raw_key_supervised, label_key,
151            patch_shape, batch_size, n_samples=n_samples_val,
152        )
153    else:
154        supervised_train_loader = None
155        supervised_val_loader = None
156
157    device = torch.device(f"cuda:{device}") if torch.cuda.is_available() else torch.device("cpu")
158    trainer = self_training.MeanTeacherTrainer(
159        name=name,
160        model=model,
161        optimizer=optimizer,
162        lr_scheduler=scheduler,
163        pseudo_labeler=pseudo_labeler,
164        unsupervised_loss=loss,
165        unsupervised_loss_and_metric=loss_and_metric,
166        supervised_train_loader=supervised_train_loader,
167        unsupervised_train_loader=unsupervised_train_loader,
168        supervised_val_loader=supervised_val_loader,
169        unsupervised_val_loader=unsupervised_val_loader,
170        supervised_loss=loss,
171        supervised_loss_and_metric=loss_and_metric,
172        logger=self_training.SelfTrainingTensorboardLogger,
173        mixed_precision=True,
174        log_image_interval=100,
175        compile_model=False,
176        device=device,
177        reinit_teacher=reinit_teacher,
178        save_root=save_root,
179        sampler=pseudo_label_sampler,
180    )
181    trainer.fit(n_iterations)

Run domain adaptation to transfer a network trained on a source domain for a supervised segmentation task to perform this task on a different target domain.

We support different domain adaptation settings:

  • unsupervised domain adaptation: the default mode when 'supervised_train_paths' and 'supervised_val_paths' are not given.
  • semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data, when 'supervised_train_paths' and 'supervised_val_paths' are given.
Arguments:
  • name: The name for the checkpoint to be trained.
  • unsupervised_train_paths: Filepaths to the hdf5 files or similar file formats for the training data in the target domain. This training data is used for unsupervised learning, so it does not require labels.
  • unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats for the validation data in the target domain. This validation data is used for unsupervised learning, so it does not require labels.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • source_checkpoint: Checkpoint to the initial model trained on the source domain. This is used to initialize the teacher model. If the checkpoint is not given, then both student and teacher model are initialized from scratch. In this case supervised_train_paths and supervised_val_paths have to be given in order to provide training data from the source domain.
  • supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain. This training data is optional. If given, it is used for supervised learning and requires labels.
  • supervised_val_paths: Filepaths to the hdf5 files for the validation data in the source domain. This validation data is optional. If given, it is used for supervised learning and requires labels.
  • confidence_threshold: The threshold for filtering data in the unsupervised loss. The label filtering is done based on the uncertainty of network predictions, and only the data with higher certainty than this threshold is used for training.
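  • raw_key: The key that holds the raw data inside of the hdf5 or similar files for the unsupervised (target domain) data.
  • raw_key_supervised: The key that holds the raw data inside of the hdf5 files for supervised learning. This is only used if supervised_train_paths and supervised_val_paths are given.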
  • label_key: The key that holds the labels inside of the hdf5 files for supervised learning. This is only required if supervised_train_paths and supervised_val_paths are given.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • train_mask_paths: Sample masks used by the patch sampler to accept or reject patches for training.
  • val_mask_paths: Sample masks used by the patch sampler to accept or reject patches for validation.
  • patch_sampler: Accept or reject patches based on a condition.
  • pseudo_label_sampler: Mask out regions of the pseudo labels where the teacher is not confident before updating the gradients.
  • device: GPU ID for training.