synapse_net.training.semisupervised_training

from typing import Optional, Tuple

import numpy as np
import torch
import torch_em
import torch_em.self_training as self_training
from torchvision import transforms

from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim


def weak_augmentations(p: float = 0.75) -> callable:
    """The weak augmentations used in the unsupervised data loader.

    Args:
        p: The probability of applying one of the augmentations.

    Returns:
        The transformation function applying the augmentations.
    """
    norm = torch_em.transform.raw.standardize
    aug = transforms.Compose([
        norm,
        transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
            scale=(0, 0.15), clip_kwargs=False)], p=p
        ),
    ])
    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
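
# Illustrative usage sketch (not part of the module; 'raw_patch' is a hypothetical
# numpy array holding a raw data patch):
#
#   transform = weak_augmentations(p=0.75)
#   augmented_view = transform(raw_patch)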


def get_unsupervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    exclude_top_and_bottom: bool = False,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for unsupervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and the size of the volumes used for training.
        exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
            avoid artifacts at the border of tomograms.

    Returns:
        The PyTorch dataloader.
    """
    # We exclude the top and bottom slices, where the tomogram reconstruction quality is bad.
    if exclude_top_and_bottom:
        roi = np.s_[5:-5, :, :]
    else:
        roi = None

    _, ndim = _determine_ndim(patch_shape)
    raw_transform = torch_em.transform.get_raw_transform()
    transform = torch_em.transform.get_augmentations(ndim=ndim)

    if n_samples is None:
        n_samples_per_ds = None
    else:
        n_samples_per_ds = int(n_samples / len(data_paths))

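    # Two weakly augmented views of each sample are produced, one for the student
    # and one for the teacher model in the mean teacher setup.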
    augmentations = (weak_augmentations(), weak_augmentations())
    datasets = [
        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
                                 augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds)
        for path in data_paths
    ]
    ds = torch.utils.data.ConcatDataset(datasets)

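    # Heuristic: scale the number of loader workers with the batch size.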
    num_workers = 4 * batch_size
    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
    return loader
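
# Illustrative usage sketch (the file paths are hypothetical):
#
#   loader = get_unsupervised_loader(
#       data_paths=("tomo1.h5", "tomo2.h5"), raw_key="raw",
#       patch_shape=(1, 512, 512), batch_size=4, n_samples=1000,
#   )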


# TODO: use different paths for supervised and unsupervised training.
# (We are currently not using this functionality directly, so this is not a high priority.)
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
103    """Run semi-supervised segmentation training.
104
105    Args:
106        name: The name for the checkpoint to be trained.
107        train_paths: Filepaths to the hdf5 files for the training data.
108        val_paths: Filepaths to the df5 files for the validation data.
109        label_key: The key that holds the labels inside of the hdf5.
110        patch_shape: The patch shape used for a training example.
111            In order to run 2d training pass a patch shape with a singleton in the z-axis,
112            e.g. 'patch_shape = [1, 512, 512]'.
113        save_root: Folder where the checkpoint will be saved.
114        raw_key: The key that holds the raw data inside of the hdf5.
115        batch_size: The batch size for training.
116        lr: The initial learning rate.
117        n_iterations: The number of iterations to train for.
118        n_samples_train: The number of train samples per epoch. By default this will be estimated
119            based on the patch_shape and size of the volumes used for training.
120        n_samples_val: The number of val samples per epoch. By default this will be estimated
121            based on the patch_shape and size of the volumes used for validation.
122        check: Whether to check the training and validation loaders instead of running training.
123    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    # TODO: also check the unsupervised loaders.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # Check whether to run 2D or 3D training, based on the patch shape.
    z, y, x = patch_shape
    is_2d = z == 1

    if is_2d:
        model = get_2d_model(out_channels=2)
    else:
        model = get_3d_model(out_channels=2)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
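    # The pseudo labeler derives pseudo-labels from the teacher predictions,
    # keeping only pixels whose confidence exceeds the threshold.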
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)
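
A minimal usage sketch (the file paths, keys, and checkpoint name below are hypothetical placeholders; adjust them to your data):

    from synapse_net.training.semisupervised_training import semisupervised_training

    semisupervised_training(
        name="semisupervised-model",                       # hypothetical checkpoint name
        train_paths=("train_tomo1.h5", "train_tomo2.h5"),  # hypothetical hdf5 files
        val_paths=("val_tomo1.h5",),
        label_key="labels",                                # hypothetical label key
        patch_shape=(48, 256, 256),                        # pass (1, 512, 512) for 2d training
        save_root="./checkpoints",
        raw_key="raw",
        batch_size=1,
        n_samples_train=1000,
        n_samples_val=100,
    )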