synapse_net.training.domain_adaptation

import os
import tempfile
from glob import glob
from pathlib import Path
from typing import Callable, Optional, Tuple

import mrcfile
import torch
import torch_em
import torch_em.self_training as self_training
from elf.io import open_file
from sklearn.model_selection import train_test_split

from .semisupervised_training import get_unsupervised_loader
from .supervised_training import (
    get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim, _derive_key_from_files
)
from ..inference.inference import get_model_path, compute_scale_from_voxel_size
from ..inference.util import _Scaler


def mean_teacher_adaptation(
    name: str,
    unsupervised_train_paths: Tuple[str],
    unsupervised_val_paths: Tuple[str],
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    source_checkpoint: Optional[str] = None,
    supervised_train_paths: Optional[Tuple[str]] = None,
    supervised_val_paths: Optional[Tuple[str]] = None,
    confidence_threshold: float = 0.9,
    raw_key: str = "raw",
    raw_key_supervised: str = "raw",
    label_key: Optional[str] = None,
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e4),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    sampler: Optional[Callable] = None,
    check: bool = False,
) -> None:
 42    """Run domain adapation to transfer a network trained on a source domain for a supervised
 43    segmentation task to perform this task on a different target domain.
 44
 45    We support different domain adaptation settings:
 46    - unsupervised domain adaptation: the default mode when 'supervised_train_paths' and
 47     'supervised_val_paths' are not given.
 48    - semi-supervised domain adaptation: domain adaptation on unlabeled and labeled data,
 49      when 'supervised_train_paths' and 'supervised_val_paths' are given.
 50
 51    Args:
 52        name: The name for the checkpoint to be trained.
 53        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
 54            for the training data in the target domain.
 55            This training data is used for unsupervised learning, so it does not require labels.
 56        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
 57            for the validation data in the target domain.
 58            This validation data is used for unsupervised learning, so it does not require labels.
 59        patch_shape: The patch shape used for a training example.
 60            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 61            e.g. 'patch_shape = [1, 512, 512]'.
 62        save_root: Folder where the checkpoint will be saved.
 63        source_checkpoint: Checkpoint to the initial model trained on the source domain.
 64            This is used to initialize the teacher model.
 65            If the checkpoint is not given, then both student and teacher model are initialized
 66            from scratch. In this case `supervised_train_paths` and `supervised_val_paths` have to
 67            be given in order to provide training data from the source domain.
 68        supervised_train_paths: Filepaths to the hdf5 files for the training data in the source domain.
 69            This training data is optional. If given, it is used for unsupervised learnig and requires labels.
 70        supervised_val_paths: Filepaths to the df5 files for the validation data in the source domain.
 71            This validation data is optional. If given, it is used for unsupervised learnig and requires labels.
 72        confidence_threshold: The threshold for filtering data in the unsupervised loss.
 73            The label filtering is done based on the uncertainty of network predictions, and only
 74            the data with higher certainty than this threshold is used for training.
 75        raw_key: The key that holds the raw data inside of the hdf5 or similar files.
 76        label_key: The key that holds the labels inside of the hdf5 files for supervised learning.
 77            This is only required if `supervised_train_paths` and `supervised_val_paths` are given.
 78        batch_size: The batch size for training.
 79        lr: The initial learning rate.
 80        n_iterations: The number of iterations to train for.
 81        n_samples_train: The number of train samples per epoch. By default this will be estimated
 82            based on the patch_shape and size of the volumes used for training.
 83        n_samples_val: The number of val samples per epoch. By default this will be estimated
 84            based on the patch_shape and size of the volumes used for validation.
 85    """
    assert (supervised_train_paths is None) == (supervised_val_paths is None)
    is_2d, _ = _determine_ndim(patch_shape)

    if source_checkpoint is None:
        # Training from scratch only makes sense if we have supervised training data,
        # which is why we have this assertion here.
        assert supervised_train_paths is not None
        print("Mean teacher training from scratch (AdaMT)")
        if is_2d:
            model = get_2d_model(out_channels=2)
        else:
            model = get_3d_model(out_channels=2)
        reinit_teacher = True
    else:
        print("Mean teacher training initialized from source model:", source_checkpoint)
        if os.path.isdir(source_checkpoint):
            model = torch_em.util.load_model(source_checkpoint)
        else:
            model = torch.load(source_checkpoint, weights_only=False)
        reinit_teacher = False

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=confidence_threshold)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    unsupervised_train_loader = get_unsupervised_loader(
        unsupervised_train_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_train
    )
    unsupervised_val_loader = get_unsupervised_loader(
        unsupervised_val_paths, raw_key, patch_shape, batch_size, n_samples=n_samples_val
    )

    if supervised_train_paths is not None:
        assert label_key is not None
        supervised_train_loader = get_supervised_loader(
            supervised_train_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_train,
        )
        supervised_val_loader = get_supervised_loader(
            supervised_val_paths, raw_key_supervised, label_key,
            patch_shape, batch_size, n_samples=n_samples_val,
        )
    else:
        supervised_train_loader = None
        supervised_val_loader = None

    # If 'check' is set, visualize samples from the data loaders instead of running training.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(unsupervised_train_loader, n_samples=4)
        check_loader(unsupervised_val_loader, n_samples=4)
        return

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=supervised_train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=supervised_val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        device=device,
        reinit_teacher=reinit_teacher,
        save_root=save_root,
        sampler=sampler,
    )
    trainer.fit(n_iterations)
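
# A minimal usage sketch for the default unsupervised setting, adapting a pretrained
# source model to unlabeled target data. All paths, keys, and the checkpoint location
# below are hypothetical and need to be adapted to your data:
#
#     mean_teacher_adaptation(
#         name="adapted_vesicle_model",
#         unsupervised_train_paths=("/path/to/target_train.h5",),
#         unsupervised_val_paths=("/path/to/target_val.h5",),
#         patch_shape=(48, 256, 256),
#         source_checkpoint="/path/to/source_checkpoint",
#         raw_key="raw",
#     )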


# TODO: patch shapes for other models.
PATCH_SHAPES = {
    "vesicles_3d": [48, 256, 256],
}
"""@private
"""


def _get_paths(input_folder, pattern, resize_training_data, model_name, tmp_dir, val_fraction):
    files = sorted(glob(os.path.join(input_folder, "**", pattern), recursive=True))
    if len(files) == 0:
        raise ValueError(f"Could not load any files from {input_folder} with pattern {pattern}")

    # Heuristic: if we have fewer than 4 files, then we crop a part of the volumes for validation
    # and resave the volumes.
    resave_val_crops = len(files) < 4

    # We only resave the data if we resave val crops or resize the training data.
    resave_data = resave_val_crops or resize_training_data
    if not resave_data:
        train_paths, val_paths = train_test_split(files, test_size=val_fraction)
        return train_paths, val_paths

    train_paths, val_paths = [], []
    for file_path in files:
        file_name = os.path.basename(file_path)
        data = open_file(file_path, mode="r")["data"][:]

        if resize_training_data:
            # Read the voxel size (in angstrom) from the mrc header and convert it to nanometer.
            with mrcfile.open(file_path) as f:
                voxel_size = f.voxel_size
            voxel_size = {ax: vox_size / 10.0 for ax, vox_size in zip("xyz", voxel_size.item())}
            scale = compute_scale_from_voxel_size(voxel_size, model_name)
            scaler = _Scaler(scale, verbose=False)
            data = scaler.scale_input(data)

        if resave_val_crops:
            # Crop the last 'val_fraction' of the slices for validation.
            n_slices = data.shape[0]
            val_slice = int((1.0 - val_fraction) * n_slices)
            train_data, val_data = data[:val_slice], data[val_slice:]

            train_path = os.path.join(tmp_dir, f"{Path(file_name).stem}_train.h5")
            with open_file(train_path, mode="w") as f:
                f.create_dataset("data", data=train_data, compression="lzf")
            train_paths.append(train_path)

            val_path = os.path.join(tmp_dir, f"{Path(file_name).stem}_val.h5")
            with open_file(val_path, mode="w") as f:
                f.create_dataset("data", data=val_data, compression="lzf")
            val_paths.append(val_path)

        else:
            output_path = os.path.join(tmp_dir, str(Path(file_name).with_suffix(".h5")))
            with open_file(output_path, mode="w") as f:
                f.create_dataset("data", data=data, compression="lzf")
            train_paths.append(output_path)

    if not resave_val_crops:
        train_paths, val_paths = train_test_split(train_paths, test_size=val_fraction)

    return train_paths, val_paths
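
# For illustration (assuming the default val_fraction=0.15): with 10 input tomograms
# and no resizing, _get_paths returns a random 8/2 train/val split of the file paths;
# with only 2 tomograms it instead crops the last 15% of slices from each volume,
# resaves train and val crops as hdf5 files in tmp_dir, and returns their paths.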


def _parse_patch_shape(patch_shape, model_name):
    if patch_shape is None:
        patch_shape = PATCH_SHAPES[model_name]
    return patch_shape


def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Adapt a model to data from a different domain using unsupervised domain adaptation.\n\n"
        "You can use this function to adapt the SynapseNet model for vesicle segmentation like this:\n"
        "synapse_net.run_domain_adaptation -n adapted_model -i /path/to/data --file_pattern *.mrc --source_model vesicles_3d\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/adapted_model' (or whichever name you pass to the '-n' argument).\n"  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI.\n"
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--name", "-n", required=True, help="The name of the model to be trained.")
    parser.add_argument("--input_folder", "-i", required=True, help="The folder with the training data.")
    parser.add_argument("--file_pattern", default="*",
                        help="The pattern for selecting files for training. For example '*.mrc' to select mrc files.")
    parser.add_argument("--key", help="The internal file path for the training data. Will be derived from the file extension by default.")  # noqa
    parser.add_argument(
        "--source_model",
        default="vesicles_3d",
        help="The source model used for weight initialization of teacher and student model. "
        "By default the model 'vesicles_3d' for vesicle segmentation in volumetric data is used."
    )
    parser.add_argument(
        "--resize_training_data", action="store_true",
        help="Whether to resize the training data to fit the voxel size of the source model's training data."
    )
    parser.add_argument("--n_iterations", type=int, default=int(1e4), help="The number of iterations for training.")
    parser.add_argument(
        "--patch_shape", nargs=3, type=int,
        help="The patch shape for training. By default the patch shape the source model was trained with is used."
    )

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa

    args = parser.parse_args()

    source_checkpoint = get_model_path(args.source_model)
    patch_shape = _parse_patch_shape(args.patch_shape, args.source_model)
    with tempfile.TemporaryDirectory() as tmp_dir:
        unsupervised_train_paths, unsupervised_val_paths = _get_paths(
            args.input_folder, args.file_pattern, args.resize_training_data, args.source_model, tmp_dir,
            args.val_fraction,
        )
        unsupervised_train_paths, raw_key = _derive_key_from_files(unsupervised_train_paths, args.key)

        mean_teacher_adaptation(
            name=args.name,
            unsupervised_train_paths=unsupervised_train_paths,
            unsupervised_val_paths=unsupervised_val_paths,
            patch_shape=patch_shape,
            source_checkpoint=source_checkpoint,
            raw_key=raw_key,
            n_iterations=args.n_iterations,
            batch_size=args.batch_size,
            n_samples_train=args.n_samples_train,
            n_samples_val=args.n_samples_val,
            check=args.check,
        )
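
# A minimal sketch of the semi-supervised setting, where labeled data from the source
# domain is passed in addition to the unlabeled target data. Without a source checkpoint
# this trains student and teacher from scratch (AdaMT). All paths and keys below are
# hypothetical:
#
#     mean_teacher_adaptation(
#         name="adapted_model_semisupervised",
#         unsupervised_train_paths=("/path/to/target_train.h5",),
#         unsupervised_val_paths=("/path/to/target_val.h5",),
#         supervised_train_paths=("/path/to/source_train.h5",),
#         supervised_val_paths=("/path/to/source_val.h5",),
#         label_key="labels",
#         patch_shape=(48, 256, 256),
#     )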