synapse_net.training.semisupervised_training

  1from typing import Optional, Tuple
  2
  3import torch
  4import torch_em
  5import torch_em.self_training as self_training
  6from torchvision import transforms
  7from torch_em.data import RawDatasetWithMasks
  8
  9from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
 10
 11
 12def weak_augmentations(p: float = 0.75) -> callable:
 13    """The weak augmentations used in the unsupervised data loader.
 14
 15    Args:
 16        p: The probability for applying one of the augmentations.
 17
 18    Returns:
 19        The transformation function applying the augmentation.
 20    """
 21    norm = torch_em.transform.raw.standardize
 22    aug = transforms.Compose([
 23        norm,
 24        transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
 25        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
 26            scale=(0, 0.15), clip_kwargs=False)], p=p
 27        ),
 28    ])
 29    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
 30
 31
 32def get_unsupervised_loader(
 33    data_paths: Tuple[str],
 34    raw_key: str,
 35    patch_shape: Tuple[int, int, int],
 36    batch_size: int,
 37    n_samples: Optional[int],
 38    sample_mask_paths: Optional[Tuple[str]] = None,
 39    sample_mask_key: Optional[str] = None,
 40    bg_mask_paths: Optional[Tuple[str]] = None,
 41    bg_mask_key: Optional[str] = None,
 42    sampler: Optional[callable] = None,
 43    exclude_top_and_bottom: bool = False,
 44) -> torch.utils.data.DataLoader:
 45    """Get a dataloader for unsupervised segmentation training.
 46
 47    Args:
 48        data_paths: The filepaths to the hdf5 files containing the training data.
 49        raw_key: The key that holds the raw data inside of the hdf5.
 50        patch_shape: The patch shape used for a training example.
 51            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 52            e.g. 'patch_shape = [1, 512, 512]'.
 53        batch_size: The batch size for training.
 54        n_samples: The number of samples per epoch. By default this will be estimated
 55            based on the patch_shape and size of the volumes used for training.
 56        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
 57        sample_mask_key: The key to the sample mask dataset inside each file.
 58        bg_mask_paths: The filepaths to the background masks for each tomogram.
 59        bg_mask_key: The key to the background mask dataset inside each file.
 60        sampler: Optional sampler to accept or reject patches for training. 
 61        exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
 62            avoid artifacts at the border of tomograms.
 63
 64    Returns:
 65        The PyTorch dataloader.
 66    """
 67    if exclude_top_and_bottom:
 68        roi = (slice(5, -5), slice(None), slice(None))
 69    else:
 70        roi = None
 71
 72    if sample_mask_paths is not None:
 73        assert len(data_paths) == len(sample_mask_paths), \
 74            f"Expected equal number of data_paths and sample_mask_paths, got {len(data_paths)} and {len(sample_mask_paths)}."
 75    if bg_mask_paths is not None:
 76        assert len(data_paths) == len(bg_mask_paths), \
 77            f"Expected equal number of data_paths and bg_mask_paths, got {len(data_paths)} and {len(bg_mask_paths)}."
 78
 79    _, ndim = _determine_ndim(patch_shape)
 80    raw_transform = torch_em.transform.get_raw_transform()
 81    transform = torch_em.transform.get_augmentations(ndim=ndim)
 82    # augmentations = (weak_augmentations(), weak_augmentations())
 83
 84    if n_samples is None:
 85        n_samples_per_ds = None
 86    else:
 87        n_samples_per_ds = int(n_samples / len(data_paths))
 88
 89    datasets = [
 90        RawDatasetWithMasks(
 91            raw_path=data_path,
 92            raw_key=raw_key,
 93            patch_shape=patch_shape,
 94            raw_transform=raw_transform,
 95            transform=transform,
 96            roi=roi,
 97            n_samples=n_samples_per_ds,
 98            sampler=sampler,
 99            ndim=ndim,
100            augmentations=None,
101            sample_mask_path=sample_mask_paths[i] if sample_mask_paths is not None else None,
102            sample_mask_key=sample_mask_key,
103            bg_mask_path=bg_mask_paths[i] if bg_mask_paths is not None else None,
104            bg_mask_key=bg_mask_key,
105        )
106        for i, data_path in enumerate(data_paths)
107    ]
108    ds = torch.utils.data.ConcatDataset(datasets)
109
110    num_workers = 4 * batch_size
111    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
112                                                   num_workers=num_workers, shuffle=True)
113    return loader
114
115
116# TODO: use different paths for supervised and unsupervised training
117# (We are currently not using this functionality directly, so this is not a high priority)
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    Training uses `self_training.MeanTeacherTrainer`. The same files are used for the
    supervised loaders (raw data and labels) and the unsupervised loaders (raw data only).

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
            Currently this only constructs the loaders and returns without training.

    Raises:
        ValueError: If `patch_shape` does not have exactly three entries.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    # TODO: visually check the (semi-supervised) loaders, e.g. via torch_em.util.debug.check_loader.
    # For now, check mode only verifies that the loaders can be constructed.
    if check:
        return

    # A singleton z-axis selects 2d training; unpacking also enforces a 3-element patch shape.
    z, _, _ = patch_shape
    if z == 1:
        model = get_2d_model(out_channels=2)
    else:
        model = get_3d_model(out_channels=2)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    # The same loss and loss/metric objects serve both the supervised and the unsupervised terms.
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)
def weak_augmentations(p: float = 0.75) -> <built-in function callable>:
def weak_augmentations(p: float = 0.75) -> callable:
    """Build the weak intensity augmentations used by the unsupervised data loader.

    Args:
        p: The probability with which each individual augmentation
            (Gaussian blur, additive Gaussian noise) is applied.

    Returns:
        The raw transform combining standardization with the random augmentations.
    """
    standardize = torch_em.transform.raw.standardize
    blur = torch_em.transform.raw.GaussianBlur()
    noise = torch_em.transform.raw.AdditiveGaussianNoise(scale=(0, 0.15), clip_kwargs=False)

    # Each augmentation is wrapped separately, so blur and noise are sampled independently.
    augmentation = transforms.Compose([
        standardize,
        transforms.RandomApply([blur], p=p),
        transforms.RandomApply([noise], p=p),
    ])
    return torch_em.transform.raw.get_raw_transform(normalizer=standardize, augmentation1=augmentation)

The weak augmentations used in the unsupervised data loader.

Arguments:
  • p: The probability for applying one of the augmentations.
Returns:

The transformation function applying the augmentation.

def get_unsupervised_loader( data_paths: Tuple[str], raw_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], sample_mask_paths: Optional[Tuple[str]] = None, sample_mask_key: Optional[str] = None, bg_mask_paths: Optional[Tuple[str]] = None, bg_mask_key: Optional[str] = None, sampler: Optional[<built-in function callable>] = None, exclude_top_and_bottom: bool = False) -> torch.utils.data.dataloader.DataLoader:
 33def get_unsupervised_loader(
 34    data_paths: Tuple[str],
 35    raw_key: str,
 36    patch_shape: Tuple[int, int, int],
 37    batch_size: int,
 38    n_samples: Optional[int],
 39    sample_mask_paths: Optional[Tuple[str]] = None,
 40    sample_mask_key: Optional[str] = None,
 41    bg_mask_paths: Optional[Tuple[str]] = None,
 42    bg_mask_key: Optional[str] = None,
 43    sampler: Optional[callable] = None,
 44    exclude_top_and_bottom: bool = False,
 45) -> torch.utils.data.DataLoader:
 46    """Get a dataloader for unsupervised segmentation training.
 47
 48    Args:
 49        data_paths: The filepaths to the hdf5 files containing the training data.
 50        raw_key: The key that holds the raw data inside of the hdf5.
 51        patch_shape: The patch shape used for a training example.
 52            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 53            e.g. 'patch_shape = [1, 512, 512]'.
 54        batch_size: The batch size for training.
 55        n_samples: The number of samples per epoch. By default this will be estimated
 56            based on the patch_shape and size of the volumes used for training.
 57        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
 58        sample_mask_key: The key to the sample mask dataset inside each file.
 59        bg_mask_paths: The filepaths to the background masks for each tomogram.
 60        bg_mask_key: The key to the background mask dataset inside each file.
 61        sampler: Optional sampler to accept or reject patches for training. 
 62        exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
 63            avoid artifacts at the border of tomograms.
 64
 65    Returns:
 66        The PyTorch dataloader.
 67    """
 68    if exclude_top_and_bottom:
 69        roi = (slice(5, -5), slice(None), slice(None))
 70    else:
 71        roi = None
 72
 73    if sample_mask_paths is not None:
 74        assert len(data_paths) == len(sample_mask_paths), \
 75            f"Expected equal number of data_paths and sample_mask_paths, got {len(data_paths)} and {len(sample_mask_paths)}."
 76    if bg_mask_paths is not None:
 77        assert len(data_paths) == len(bg_mask_paths), \
 78            f"Expected equal number of data_paths and bg_mask_paths, got {len(data_paths)} and {len(bg_mask_paths)}."
 79
 80    _, ndim = _determine_ndim(patch_shape)
 81    raw_transform = torch_em.transform.get_raw_transform()
 82    transform = torch_em.transform.get_augmentations(ndim=ndim)
 83    # augmentations = (weak_augmentations(), weak_augmentations())
 84
 85    if n_samples is None:
 86        n_samples_per_ds = None
 87    else:
 88        n_samples_per_ds = int(n_samples / len(data_paths))
 89
 90    datasets = [
 91        RawDatasetWithMasks(
 92            raw_path=data_path,
 93            raw_key=raw_key,
 94            patch_shape=patch_shape,
 95            raw_transform=raw_transform,
 96            transform=transform,
 97            roi=roi,
 98            n_samples=n_samples_per_ds,
 99            sampler=sampler,
100            ndim=ndim,
101            augmentations=None,
102            sample_mask_path=sample_mask_paths[i] if sample_mask_paths is not None else None,
103            sample_mask_key=sample_mask_key,
104            bg_mask_path=bg_mask_paths[i] if bg_mask_paths is not None else None,
105            bg_mask_key=bg_mask_key,
106        )
107        for i, data_path in enumerate(data_paths)
108    ]
109    ds = torch.utils.data.ConcatDataset(datasets)
110
111    num_workers = 4 * batch_size
112    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
113                                                   num_workers=num_workers, shuffle=True)
114    return loader

Get a dataloader for unsupervised segmentation training.

Arguments:
  • data_paths: The filepaths to the hdf5 files containing the training data.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • batch_size: The batch size for training.
  • n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
  • sample_mask_key: The key to the sample mask dataset inside each file.
  • bg_mask_paths: The filepaths to the background masks for each tomogram.
  • bg_mask_key: The key to the background mask dataset inside each file.
  • sampler: Optional sampler to accept or reject patches for training.
  • exclude_top_and_bottom: Whether to exclude the five top and bottom slices to avoid artifacts at the border of tomograms.
Returns:

The PyTorch dataloader.

def semisupervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: str, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False) -> None:
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    Training uses `self_training.MeanTeacherTrainer`. The same files are used for the
    supervised loaders (raw data and labels) and the unsupervised loaders (raw data only).

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
            Currently this only constructs the loaders and returns without training.

    Raises:
        ValueError: If `patch_shape` does not have exactly three entries.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    # TODO: visually check the (semi-supervised) loaders, e.g. via torch_em.util.debug.check_loader.
    # For now, check mode only verifies that the loaders can be constructed.
    if check:
        return

    # A singleton z-axis selects 2d training; unpacking also enforces a 3-element patch shape.
    z, _, _ = patch_shape
    if z == 1:
        model = get_2d_model(out_channels=2)
    else:
        model = get_3d_model(out_channels=2)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    # The same loss and loss/metric objects serve both the supervised and the unsupervised terms.
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)

Run semi-supervised segmentation training.

Arguments:
  • name: The name for the checkpoint to be trained.
  • train_paths: Filepaths to the hdf5 files for the training data.
  • val_paths: Filepaths to the hdf5 files for the validation data.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • check: Whether to check the training and validation loaders instead of running training.