synapse_net.training.semisupervised_training
from typing import Optional, Tuple

import numpy as np
import torch
import torch_em
import torch_em.self_training as self_training
from torchvision import transforms

from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim


def weak_augmentations(p: float = 0.75) -> callable:
    """The weak augmentations used in the unsupervised data loader.

    Args:
        p: The probability for applying one of the augmentations.

    Returns:
        The transformation function applying the augmentation.
    """
    norm = torch_em.transform.raw.standardize
    aug = transforms.Compose([
        norm,
        transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
            scale=(0, 0.15), clip_kwargs=False)], p=p
        ),
    ])
    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


def get_unsupervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    exclude_top_and_bottom: bool = False,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for unsupervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
            avoid artifacts at the border of tomograms.

    Returns:
        The PyTorch dataloader.
    """
    # We exclude the top and bottom slices where the tomogram reconstruction is bad.
    if exclude_top_and_bottom:
        roi = np.s_[5:-5, :, :]
    else:
        roi = None

    _, ndim = _determine_ndim(patch_shape)
    raw_transform = torch_em.transform.get_raw_transform()
    transform = torch_em.transform.get_augmentations(ndim=ndim)

    if n_samples is None:
        n_samples_per_ds = None
    else:
        n_samples_per_ds = int(n_samples / len(data_paths))

    # Two independently sampled weak augmentations yield two views of each patch.
    augmentations = (weak_augmentations(), weak_augmentations())
    datasets = [
        torch_em.data.RawDataset(path, raw_key, patch_shape, raw_transform, transform,
                                 augmentations=augmentations, roi=roi, ndim=ndim, n_samples=n_samples_per_ds)
        for path in data_paths
    ]
    ds = torch.utils.data.ConcatDataset(datasets)

    num_workers = 4 * batch_size
    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
    return loader


# TODO: use different paths for supervised and unsupervised training.
# (We are currently not using this functionality directly, so this is not a high priority.)
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    # TODO: also check the unsupervised loaders.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # Check for 2d or 3d training: a singleton z-axis selects the 2d model.
    z, y, x = patch_shape
    is_2d = z == 1

    if is_2d:
        model = get_2d_model(out_channels=2)
    else:
        model = get_3d_model(out_channels=2)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)
def
weak_augmentations(p: float = 0.75) -> callable:
The weak augmentations used in the unsupervised data loader.
Arguments:
- p: The probability for applying one of the augmentations.
Returns:
The transformation function applying the augmentation.
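A minimal usage sketch, assuming the returned transform can be called directly on a numpy patch (the toy array below is a hypothetical stand-in for a tomogram crop):

import numpy as np

from synapse_net.training.semisupervised_training import weak_augmentations

# Toy stand-in for a raw tomogram patch (hypothetical shape).
raw = np.random.rand(32, 256, 256).astype("float32")

# Each augmentation in the composed transform fires with probability p.
transform = weak_augmentations(p=0.5)
augmented = transform(raw)  # standardized, possibly blurred and/or noised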
def
get_unsupervised_loader( data_paths: Tuple[str], raw_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], exclude_top_and_bottom: bool = False) -> torch.utils.data.dataloader.DataLoader:
Get a dataloader for unsupervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- exclude_top_and_bottom: Whether to exclude the five top and bottom slices to avoid artifacts at the border of tomograms.
Returns:
The PyTorch dataloader.
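A hedged construction sketch; the hdf5 paths below are placeholders, and we assume each batch yields a pair of differently weak-augmented views of the same patches (set up via the augmentations tuple in the source):

from synapse_net.training.semisupervised_training import get_unsupervised_loader

# Placeholder tomogram files; replace with your own hdf5 paths.
paths = ("/data/tomo_a.h5", "/data/tomo_b.h5")

loader = get_unsupervised_loader(
    data_paths=paths,
    raw_key="raw",
    patch_shape=(48, 256, 256),   # 3d patches; use (1, 512, 512) for 2d
    batch_size=2,
    n_samples=500,                # i.e. 250 patches drawn per file
    exclude_top_and_bottom=True,  # crop five slices at the z-borders
)

view_a, view_b = next(iter(loader))  # two weakly augmented views per batch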
def
semisupervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: str, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False) -> None:
Run semi-supervised segmentation training.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.
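A hedged end-to-end sketch of invoking the training. The checkpoint name, file paths, and label key are placeholders; the singleton z-axis in the patch shape selects the 2d model:

from synapse_net.training.semisupervised_training import semisupervised_training

semisupervised_training(
    name="vesicle-semisup",       # hypothetical checkpoint name
    train_paths=("/data/train_1.h5", "/data/train_2.h5"),  # placeholder paths
    val_paths=("/data/val_1.h5",),
    label_key="labels/vesicles",  # placeholder label key
    patch_shape=(1, 512, 512),    # singleton z-axis -> 2d training
    save_root="./checkpoints",
    n_iterations=int(5e4),
    check=False,                  # set to True to only inspect the loaders
)

Note that the same file paths currently feed both the supervised and the unsupervised loaders, as flagged by the TODO in the source above.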