synapse_net.training.semisupervised_training
"""Semi-supervised segmentation training based on torch_em's mean-teacher self-training."""
from typing import Optional, Tuple

import torch
import torch_em
import torch_em.self_training as self_training
from torchvision import transforms
from torch_em.data import RawDatasetWithMasks

from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim


def weak_augmentations(p: float = 0.75) -> callable:
    """Build the weak augmentations used in the unsupervised data loader.

    The augmentation pipeline is composed of the standardization, a Gaussian blur and
    additive Gaussian noise; blur and noise are each wrapped in a `RandomApply` with
    probability `p`. The pipeline is passed as `augmentation1` to torch_em's
    `get_raw_transform`, together with the same standardization as `normalizer`.

    Args:
        p: The probability for applying one of the augmentations.

    Returns:
        The transformation function applying the augmentation.
    """
    raw_transforms = torch_em.transform.raw
    standardize = raw_transforms.standardize

    blur = transforms.RandomApply([raw_transforms.GaussianBlur()], p=p)
    noise = transforms.RandomApply(
        [raw_transforms.AdditiveGaussianNoise(scale=(0, 0.15), clip_kwargs=False)], p=p
    )
    pipeline = transforms.Compose([standardize, blur, noise])
    return raw_transforms.get_raw_transform(normalizer=standardize, augmentation1=pipeline)


def get_unsupervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    sample_mask_paths: Optional[Tuple[str]] = None,
    sample_mask_key: Optional[str] = None,
    bg_mask_paths: Optional[Tuple[str]] = None,
    bg_mask_key: Optional[str] = None,
    sampler: Optional[callable] = None,
    exclude_top_and_bottom: bool = False,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for unsupervised segmentation training.

    One `RawDatasetWithMasks` is created per entry of `data_paths`; the datasets are
    concatenated and wrapped in a shuffling data loader.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. If None, no sample count is passed
            to the datasets. Otherwise it is split evenly across `data_paths`.
        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
            Must have the same length as `data_paths` if given.
        sample_mask_key: The key to the sample mask dataset inside each file.
        bg_mask_paths: The filepaths to the background masks for each tomogram.
            Must have the same length as `data_paths` if given.
        bg_mask_key: The key to the background mask dataset inside each file.
        sampler: Optional sampler to accept or reject patches for training.
        exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
            avoid artifacts at the border of tomograms.

    Returns:
        The PyTorch dataloader.
    """
    # Restrict the first axis to [5, -5) when the border slices should be skipped.
    roi = (slice(5, -5), slice(None), slice(None)) if exclude_top_and_bottom else None

    # Mask paths are matched to the data paths by position, so the lengths must agree.
    if sample_mask_paths is not None:
        assert len(data_paths) == len(sample_mask_paths), \
            f"Expected equal number of data_paths and sample_mask_paths, got {len(data_paths)} and {len(sample_mask_paths)}."
    if bg_mask_paths is not None:
        assert len(data_paths) == len(bg_mask_paths), \
            f"Expected equal number of data_paths and bg_mask_paths, got {len(data_paths)} and {len(bg_mask_paths)}."

    _, ndim = _determine_ndim(patch_shape)
    raw_transform = torch_em.transform.get_raw_transform()
    transform = torch_em.transform.get_augmentations(ndim=ndim)
    # augmentations = (weak_augmentations(), weak_augmentations())

    # Distribute the requested number of samples evenly over the individual datasets.
    n_samples_per_ds = None if n_samples is None else int(n_samples / len(data_paths))

    datasets = []
    for index, raw_path in enumerate(data_paths):
        sample_mask_path = None if sample_mask_paths is None else sample_mask_paths[index]
        bg_mask_path = None if bg_mask_paths is None else bg_mask_paths[index]
        datasets.append(
            RawDatasetWithMasks(
                raw_path=raw_path,
                raw_key=raw_key,
                patch_shape=patch_shape,
                raw_transform=raw_transform,
                transform=transform,
                roi=roi,
                n_samples=n_samples_per_ds,
                sampler=sampler,
                ndim=ndim,
                augmentations=None,
                sample_mask_path=sample_mask_path,
                sample_mask_key=sample_mask_key,
                bg_mask_path=bg_mask_path,
                bg_mask_key=bg_mask_key,
            )
        )
    concatenated = torch.utils.data.ConcatDataset(datasets)

    return torch_em.segmentation.get_data_loader(
        concatenated, batch_size=batch_size, num_workers=4 * batch_size, shuffle=True
    )


# TODO: use different paths for supervised and unsupervised training
# (We are currently not using this functionality directly, so this is not a high priority)
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    The same files are used for the supervised and the unsupervised loaders.
    Training is run with torch_em's `MeanTeacherTrainer`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: If True, only build the data loaders and return without running training.
            The visual loader check itself is currently disabled (see the TODO in the body).
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    # TODO check the semisup loader
    if check:
        # from torch_em.util.debug import check_loader
        # check_loader(train_loader, n_samples=4)
        # check_loader(val_loader, n_samples=4)
        return

    # A singleton z-axis in the patch shape selects 2D training, otherwise 3D.
    z, _, _ = patch_shape
    is_2d = z == 1
    model = get_2d_model(out_channels=2) if is_2d else get_3d_model(out_channels=2)

    # Use the learning rate requested by the caller (previously hard-coded to 1e-4).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)
def
weak_augmentations(p: float = 0.75) -> <built-in function callable>:
def weak_augmentations(p: float = 0.75) -> callable:
    """Build the weak augmentations used in the unsupervised data loader.

    The augmentation pipeline is composed of the standardization, a Gaussian blur and
    additive Gaussian noise; blur and noise are each wrapped in a `RandomApply` with
    probability `p`. The pipeline is passed as `augmentation1` to torch_em's
    `get_raw_transform`, together with the same standardization as `normalizer`.

    Args:
        p: The probability for applying one of the augmentations.

    Returns:
        The transformation function applying the augmentation.
    """
    raw_transforms = torch_em.transform.raw
    standardize = raw_transforms.standardize

    blur = transforms.RandomApply([raw_transforms.GaussianBlur()], p=p)
    noise = transforms.RandomApply(
        [raw_transforms.AdditiveGaussianNoise(scale=(0, 0.15), clip_kwargs=False)], p=p
    )
    pipeline = transforms.Compose([standardize, blur, noise])
    return raw_transforms.get_raw_transform(normalizer=standardize, augmentation1=pipeline)
The weak augmentations used in the unsupervised data loader.
Arguments:
- p: The probability for applying one of the augmentations.
Returns:
The transformation function applying the augmentation.
def
get_unsupervised_loader( data_paths: Tuple[str], raw_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], sample_mask_paths: Optional[Tuple[str]] = None, sample_mask_key: Optional[str] = None, bg_mask_paths: Optional[Tuple[str]] = None, bg_mask_key: Optional[str] = None, sampler: Optional[<built-in function callable>] = None, exclude_top_and_bottom: bool = False) -> torch.utils.data.dataloader.DataLoader:
def get_unsupervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    sample_mask_paths: Optional[Tuple[str]] = None,
    sample_mask_key: Optional[str] = None,
    bg_mask_paths: Optional[Tuple[str]] = None,
    bg_mask_key: Optional[str] = None,
    sampler: Optional[callable] = None,
    exclude_top_and_bottom: bool = False,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for unsupervised segmentation training.

    One `RawDatasetWithMasks` is created per entry of `data_paths`; the datasets are
    concatenated and wrapped in a shuffling data loader.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. If None, no sample count is passed
            to the datasets. Otherwise it is split evenly across `data_paths`.
        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
            Must have the same length as `data_paths` if given.
        sample_mask_key: The key to the sample mask dataset inside each file.
        bg_mask_paths: The filepaths to the background masks for each tomogram.
            Must have the same length as `data_paths` if given.
        bg_mask_key: The key to the background mask dataset inside each file.
        sampler: Optional sampler to accept or reject patches for training.
        exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
            avoid artifacts at the border of tomograms.

    Returns:
        The PyTorch dataloader.
    """
    # Restrict the first axis to [5, -5) when the border slices should be skipped.
    roi = (slice(5, -5), slice(None), slice(None)) if exclude_top_and_bottom else None

    # Mask paths are matched to the data paths by position, so the lengths must agree.
    if sample_mask_paths is not None:
        assert len(data_paths) == len(sample_mask_paths), \
            f"Expected equal number of data_paths and sample_mask_paths, got {len(data_paths)} and {len(sample_mask_paths)}."
    if bg_mask_paths is not None:
        assert len(data_paths) == len(bg_mask_paths), \
            f"Expected equal number of data_paths and bg_mask_paths, got {len(data_paths)} and {len(bg_mask_paths)}."

    _, ndim = _determine_ndim(patch_shape)
    raw_transform = torch_em.transform.get_raw_transform()
    transform = torch_em.transform.get_augmentations(ndim=ndim)
    # augmentations = (weak_augmentations(), weak_augmentations())

    # Distribute the requested number of samples evenly over the individual datasets.
    n_samples_per_ds = None if n_samples is None else int(n_samples / len(data_paths))

    datasets = []
    for index, raw_path in enumerate(data_paths):
        sample_mask_path = None if sample_mask_paths is None else sample_mask_paths[index]
        bg_mask_path = None if bg_mask_paths is None else bg_mask_paths[index]
        datasets.append(
            RawDatasetWithMasks(
                raw_path=raw_path,
                raw_key=raw_key,
                patch_shape=patch_shape,
                raw_transform=raw_transform,
                transform=transform,
                roi=roi,
                n_samples=n_samples_per_ds,
                sampler=sampler,
                ndim=ndim,
                augmentations=None,
                sample_mask_path=sample_mask_path,
                sample_mask_key=sample_mask_key,
                bg_mask_path=bg_mask_path,
                bg_mask_key=bg_mask_key,
            )
        )
    concatenated = torch.utils.data.ConcatDataset(datasets)

    return torch_em.segmentation.get_data_loader(
        concatenated, batch_size=batch_size, num_workers=4 * batch_size, shuffle=True
    )
Get a dataloader for unsupervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
- sample_mask_key: The key to the sample mask dataset inside each file.
- bg_mask_paths: The filepaths to the background masks for each tomogram.
- bg_mask_key: The key to the background mask dataset inside each file.
- sampler: Optional sampler to accept or reject patches for training.
- exclude_top_and_bottom: Whether to exclude the five top and bottom slices to avoid artifacts at the border of tomograms.
Returns:
The PyTorch dataloader.
def
semisupervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: str, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False) -> None:
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    The same files are used for the supervised and the unsupervised loaders.
    Training is run with torch_em's `MeanTeacherTrainer`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: If True, only build the data loaders and return without running training.
            The visual loader check itself is currently disabled (see the TODO in the body).
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    # TODO check the semisup loader
    if check:
        # from torch_em.util.debug import check_loader
        # check_loader(train_loader, n_samples=4)
        # check_loader(val_loader, n_samples=4)
        return

    # A singleton z-axis in the patch shape selects 2D training, otherwise 3D.
    z, _, _ = patch_shape
    is_2d = z == 1
    model = get_2d_model(out_channels=2) if is_2d else get_3d_model(out_channels=2)

    # Use the learning rate requested by the caller (previously hard-coded to 1e-4).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)
Run semi-supervised segmentation training.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.