synapse_net.training.semisupervised_training

  1from typing import Optional, Tuple
  2
  3import numpy as np
  4import uuid
  5import h5py
  6import torch
  7import torch_em
  8import torch_em.self_training as self_training
  9from torchvision import transforms
 10
 11from synapse_net.file_utils import read_mrc
 12from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim
 13
 14
 15def weak_augmentations(p: float = 0.75) -> callable:
 16    """The weak augmentations used in the unsupervised data loader.
 17
 18    Args:
 19        p: The probability for applying one of the augmentations.
 20
 21    Returns:
 22        The transformation function applying the augmentation.
 23    """
 24    norm = torch_em.transform.raw.standardize
 25    aug = transforms.Compose([
 26        norm,
 27        transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
 28        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
 29            scale=(0, 0.15), clip_kwargs=False)], p=p
 30        ),
 31    ])
 32    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)
 33
 34def drop_mask_channel(x):
 35    x = x[:1]
 36    return x
 37    
 38class ComposedTransform:
 39    def __init__(self, *funcs):
 40        self.funcs = funcs
 41
 42    def __call__(self, x):
 43        for f in self.funcs:
 44            x = f(x)
 45        return x
 46
 47class ChannelSplitterSampler: 
 48    def __init__(self, sampler):
 49        self.sampler = sampler
 50    
 51    def __call__(self, x):
 52        raw, mask = x[0], x[1]
 53        return self.sampler(raw, mask)
 54
 55def get_unsupervised_loader(
 56    data_paths: Tuple[str],
 57    raw_key: str,
 58    patch_shape: Tuple[int, int, int],
 59    batch_size: int,
 60    n_samples: Optional[int],
 61    sample_mask_paths: Optional[Tuple[str]] = None,
 62    sampler: Optional[callable] = None,
 63    exclude_top_and_bottom: bool = False, 
 64) -> torch.utils.data.DataLoader:
 65    """Get a dataloader for unsupervised segmentation training.
 66
 67    Args:
 68        data_paths: The filepaths to the hdf5 files containing the training data.
 69        raw_key: The key that holds the raw data inside of the hdf5.
 70        patch_shape: The patch shape used for a training example.
 71            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 72            e.g. 'patch_shape = [1, 512, 512]'.
 73        batch_size: The batch size for training.
 74        n_samples: The number of samples per epoch. By default this will be estimated
 75            based on the patch_shape and size of the volumes used for training.
 76        exclude_top_and_bottom: Whether to exluce the five top and bottom slices to
 77            avoid artifacts at the border of tomograms.
 78        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
 79        sampler: Accept or reject patches based on a condition.
 80
 81    Returns:
 82        The PyTorch dataloader.
 83    """
 84    # We exclude the top and bottom slices where the tomogram reconstruction is bad.
 85    # TODO this seems unneccesary if we have a boundary mask - remove? 
 86    if exclude_top_and_bottom:
 87        roi = np.s_[5:-5, :, :]
 88    else:
 89        roi = None
 90    # stack tomograms and masks and write to temp files to use as input to RawDataset()    
 91    if sample_mask_paths is not None:
 92        assert len(data_paths) == len(sample_mask_paths), \
 93            f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths."
 94        
 95        stacked_paths = []
 96        for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)):
 97            raw = read_mrc(data_path)[0]
 98            mask = read_mrc(mask_path)[0]
 99            stacked = np.stack([raw, mask], axis=0)
100
101            tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5"
102            with h5py.File(tmp_path, "w") as f:
103                f.create_dataset("raw", data=stacked, compression="gzip")
104            stacked_paths.append(tmp_path)
105
106        # update variables for RawDataset()
107        data_paths = tuple(stacked_paths)    
108        base_transform = torch_em.transform.get_raw_transform()
109        raw_transform = ComposedTransform(base_transform, drop_mask_channel)
110        sampler = ChannelSplitterSampler(sampler)
111        with_channels = True 
112    else:
113        raw_transform = torch_em.transform.get_raw_transform()
114        with_channels = False
115        sampler = None
116
117    _, ndim = _determine_ndim(patch_shape)
118    transform = torch_em.transform.get_augmentations(ndim=ndim)
119
120    if n_samples is None:
121        n_samples_per_ds = None
122    else:
123        n_samples_per_ds = int(n_samples / len(data_paths))
124
125    augmentations = (weak_augmentations(), weak_augmentations())
126
127    datasets = [
128        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi,
129        n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations)
130        for path in data_paths
131    ]
132    ds = torch.utils.data.ConcatDataset(datasets)
133
134    num_workers = 4 * batch_size
135    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
136                                                   num_workers=num_workers, shuffle=True)
137    return loader
138
139
140# TODO: use different paths for supervised and unsupervised training
141# (We are currently not using this functionality directly, so this is not a high priority)
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    Builds supervised loaders (raw data + labels) and unsupervised loaders over the same
    files, then trains a model with 2 output channels using torch_em's `MeanTeacherTrainer`.
    Pseudo-labels come from `DefaultPseudoLabeler` with a confidence threshold of 0.9.
    The learning rate is reduced by a factor of 0.5 via `ReduceLROnPlateau` (patience 5).

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate of the Adam optimizer.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: If True, only build the data loaders and return without training
            (inspecting the loaders is not implemented yet).
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    # The unsupervised loaders currently read the same files as the supervised ones.
    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    if check:
        # TODO: inspect the (semi-supervised) loaders here, e.g. with
        # torch_em.util.debug.check_loader(loader, n_samples=4).
        return

    # 2D training is selected by a singleton z-axis in the patch shape.
    z, _, _ = patch_shape
    is_2d = z == 1
    model = get_2d_model(out_channels=2) if is_2d else get_3d_model(out_channels=2)

    # Honour the requested learning rate.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)
def weak_augmentations(p: float = 0.75) -> <built-in function callable>:
def weak_augmentations(p: float = 0.75) -> callable:
    """Build the weak augmentation pipeline used by the unsupervised data loader.

    The pipeline first applies ``torch_em.transform.raw.standardize`` and then,
    each with probability ``p``, a Gaussian blur and additive Gaussian noise.

    Args:
        p: The probability with which each individual augmentation is applied.

    Returns:
        The raw transform produced by ``torch_em.transform.raw.get_raw_transform``
        from the standardization normalizer and the blur/noise augmentation.
    """
    raw_trafos = torch_em.transform.raw
    normalizer = raw_trafos.standardize
    maybe_blur = transforms.RandomApply([raw_trafos.GaussianBlur()], p=p)
    maybe_noise = transforms.RandomApply(
        [raw_trafos.AdditiveGaussianNoise(scale=(0, 0.15), clip_kwargs=False)], p=p
    )
    augmentation = transforms.Compose([normalizer, maybe_blur, maybe_noise])
    return raw_trafos.get_raw_transform(normalizer=normalizer, augmentation1=augmentation)

The weak augmentations used in the unsupervised data loader.

Arguments:
  • p: The probability for applying one of the augmentations.
Returns:

The transformation function applying the augmentation.

def drop_mask_channel(x):
def drop_mask_channel(x):
    """Keep only the first channel of a channel-stacked input.

    The leading axis is sliced (not indexed), so it is preserved with length one.
    """
    first_channel = slice(None, 1)
    return x[first_channel]
class ComposedTransform:
class ComposedTransform:
    """Chain several callables into a single transform.

    Calling the instance feeds the input through every stored function in the
    order they were given; each function receives the previous one's output.
    With no functions, the input is returned unchanged.
    """

    def __init__(self, *funcs):
        # Tuple of callables collected from *funcs, kept in application order.
        self.funcs = funcs

    def __call__(self, x):
        result = x
        for func in self.funcs:
            result = func(result)
        return result
ComposedTransform(*funcs)
40    def __init__(self, *funcs):
41        self.funcs = funcs
funcs
class ChannelSplitterSampler:
class ChannelSplitterSampler:
    """Adapt a two-argument sampler to channel-stacked patches.

    The wrapped ``sampler`` is called as ``sampler(raw, mask)``, where ``raw`` is
    element 0 and ``mask`` is element 1 along the leading axis of the input.
    Its return value is passed back unchanged.
    """

    def __init__(self, sampler):
        # The wrapped callable taking (raw, mask).
        self.sampler = sampler

    def __call__(self, x):
        raw_channel = x[0]
        mask_channel = x[1]
        return self.sampler(raw_channel, mask_channel)
ChannelSplitterSampler(sampler)
49    def __init__(self, sampler):
50        self.sampler = sampler
sampler
def get_unsupervised_loader( data_paths: Tuple[str], raw_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], sample_mask_paths: Optional[Tuple[str]] = None, sampler: Optional[<built-in function callable>] = None, exclude_top_and_bottom: bool = False) -> torch.utils.data.dataloader.DataLoader:
 56def get_unsupervised_loader(
 57    data_paths: Tuple[str],
 58    raw_key: str,
 59    patch_shape: Tuple[int, int, int],
 60    batch_size: int,
 61    n_samples: Optional[int],
 62    sample_mask_paths: Optional[Tuple[str]] = None,
 63    sampler: Optional[callable] = None,
 64    exclude_top_and_bottom: bool = False, 
 65) -> torch.utils.data.DataLoader:
 66    """Get a dataloader for unsupervised segmentation training.
 67
 68    Args:
 69        data_paths: The filepaths to the hdf5 files containing the training data.
 70        raw_key: The key that holds the raw data inside of the hdf5.
 71        patch_shape: The patch shape used for a training example.
 72            In order to run 2d training pass a patch shape with a singleton in the z-axis,
 73            e.g. 'patch_shape = [1, 512, 512]'.
 74        batch_size: The batch size for training.
 75        n_samples: The number of samples per epoch. By default this will be estimated
 76            based on the patch_shape and size of the volumes used for training.
 77        exclude_top_and_bottom: Whether to exluce the five top and bottom slices to
 78            avoid artifacts at the border of tomograms.
 79        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
 80        sampler: Accept or reject patches based on a condition.
 81
 82    Returns:
 83        The PyTorch dataloader.
 84    """
 85    # We exclude the top and bottom slices where the tomogram reconstruction is bad.
 86    # TODO this seems unneccesary if we have a boundary mask - remove? 
 87    if exclude_top_and_bottom:
 88        roi = np.s_[5:-5, :, :]
 89    else:
 90        roi = None
 91    # stack tomograms and masks and write to temp files to use as input to RawDataset()    
 92    if sample_mask_paths is not None:
 93        assert len(data_paths) == len(sample_mask_paths), \
 94            f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths."
 95        
 96        stacked_paths = []
 97        for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)):
 98            raw = read_mrc(data_path)[0]
 99            mask = read_mrc(mask_path)[0]
100            stacked = np.stack([raw, mask], axis=0)
101
102            tmp_path = f"/tmp/stacked{i}_{uuid.uuid4().hex}.h5"
103            with h5py.File(tmp_path, "w") as f:
104                f.create_dataset("raw", data=stacked, compression="gzip")
105            stacked_paths.append(tmp_path)
106
107        # update variables for RawDataset()
108        data_paths = tuple(stacked_paths)    
109        base_transform = torch_em.transform.get_raw_transform()
110        raw_transform = ComposedTransform(base_transform, drop_mask_channel)
111        sampler = ChannelSplitterSampler(sampler)
112        with_channels = True 
113    else:
114        raw_transform = torch_em.transform.get_raw_transform()
115        with_channels = False
116        sampler = None
117
118    _, ndim = _determine_ndim(patch_shape)
119    transform = torch_em.transform.get_augmentations(ndim=ndim)
120
121    if n_samples is None:
122        n_samples_per_ds = None
123    else:
124        n_samples_per_ds = int(n_samples / len(data_paths))
125
126    augmentations = (weak_augmentations(), weak_augmentations())
127
128    datasets = [
129        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform, roi=roi,
130        n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim, with_channels=with_channels, augmentations=augmentations)
131        for path in data_paths
132    ]
133    ds = torch.utils.data.ConcatDataset(datasets)
134
135    num_workers = 4 * batch_size
136    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
137                                                   num_workers=num_workers, shuffle=True)
138    return loader

Get a dataloader for unsupervised segmentation training.

Arguments:
  • data_paths: The filepaths to the hdf5 files containing the training data.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • batch_size: The batch size for training.
  • n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • exclude_top_and_bottom: Whether to exclude the five top and bottom slices to avoid artifacts at the border of tomograms.
  • sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
  • sampler: Accept or reject patches based on a condition.
Returns:

The PyTorch dataloader.

def semisupervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: str, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False) -> None:
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    Builds supervised loaders (raw data + labels) and unsupervised loaders over the same
    files, then trains a model with 2 output channels using torch_em's `MeanTeacherTrainer`.
    Pseudo-labels come from `DefaultPseudoLabeler` with a confidence threshold of 0.9.
    The learning rate is reduced by a factor of 0.5 via `ReduceLROnPlateau` (patience 5).

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate of the Adam optimizer.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: If True, only build the data loaders and return without training
            (inspecting the loaders is not implemented yet).
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    # The unsupervised loaders currently read the same files as the supervised ones.
    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    if check:
        # TODO: inspect the (semi-supervised) loaders here, e.g. with
        # torch_em.util.debug.check_loader(loader, n_samples=4).
        return

    # 2D training is selected by a singleton z-axis in the patch shape.
    z, _, _ = patch_shape
    is_2d = z == 1
    model = get_2d_model(out_channels=2) if is_2d else get_3d_model(out_channels=2)

    # Honour the requested learning rate.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)

Run semi-supervised segmentation training.

Arguments:
  • name: The name for the checkpoint to be trained.
  • train_paths: Filepaths to the hdf5 files for the training data.
  • val_paths: Filepaths to the hdf5 files for the validation data.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • check: Whether to check the training and validation loaders instead of running training.