synapse_net.training.semisupervised_training
"""Data loaders and the training entry point for semi-supervised (mean-teacher) segmentation training."""
import os
import tempfile
import uuid
from typing import Callable, Optional, Tuple

import h5py
import numpy as np
import torch
import torch_em
import torch_em.self_training as self_training
from torchvision import transforms

from synapse_net.file_utils import read_mrc
from .supervised_training import get_2d_model, get_3d_model, get_supervised_loader, _determine_ndim


def weak_augmentations(p: float = 0.75) -> Callable:
    """The weak augmentations used in the unsupervised data loader.

    The augmentation pipeline standardizes the input and then applies a Gaussian blur
    and additive Gaussian noise, each independently with probability `p`.

    Args:
        p: The probability for applying one of the augmentations.

    Returns:
        The transformation function applying the augmentation.
    """
    norm = torch_em.transform.raw.standardize
    aug = transforms.Compose([
        norm,
        transforms.RandomApply([torch_em.transform.raw.GaussianBlur()], p=p),
        transforms.RandomApply([torch_em.transform.raw.AdditiveGaussianNoise(
            scale=(0, 0.15), clip_kwargs=False)], p=p
        ),
    ])
    return torch_em.transform.raw.get_raw_transform(normalizer=norm, augmentation1=aug)


def drop_mask_channel(x):
    """Keep only the first channel of `x` (the raw data), dropping the stacked mask channel."""
    x = x[:1]
    return x


class ComposedTransform:
    """Apply a sequence of single-argument functions one after the other."""

    def __init__(self, *funcs):
        # The functions are applied in the order in which they are passed.
        self.funcs = funcs

    def __call__(self, x):
        for f in self.funcs:
            x = f(x)
        return x


class ChannelSplitterSampler:
    """Adapt a `sampler(raw, mask)` callable to a two-channel input of stacked raw data and mask."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __call__(self, x):
        # Channel 0 holds the raw data and channel 1 the sample mask.
        raw, mask = x[0], x[1]
        return self.sampler(raw, mask)


def get_unsupervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    sample_mask_paths: Optional[Tuple[str]] = None,
    sampler: Optional[Callable] = None,
    exclude_top_and_bottom: bool = False,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for unsupervised segmentation training.

    Args:
        data_paths: The filepaths to the files containing the training data.
            These are hdf5 files, unless `sample_mask_paths` is given; in that case
            the data (and the masks) are read with `read_mrc`.
        raw_key: The key that holds the raw data inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
        sampler: Accept or reject patches based on a condition. It is called as `sampler(raw, mask)`
            and is only used if `sample_mask_paths` is given.
        exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
            avoid artifacts at the border of tomograms.

    Returns:
        The PyTorch dataloader.
    """
    # We exclude the top and bottom slices where the tomogram reconstruction is bad.
    # TODO this seems unnecessary if we have a boundary mask - remove?
    roi = np.s_[5:-5, :, :] if exclude_top_and_bottom else None

    # Stack tomograms and masks and write them to temp files to use as input to RawDataset().
    if sample_mask_paths is not None:
        assert len(data_paths) == len(sample_mask_paths), \
            f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths."

        # Use the platform's temp directory instead of a hard-coded '/tmp'.
        # NOTE: The stacked files are not removed by this function.
        tmp_dir = tempfile.gettempdir()
        stacked_paths = []
        for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)):
            raw = read_mrc(data_path)[0]
            mask = read_mrc(mask_path)[0]
            stacked = np.stack([raw, mask], axis=0)

            tmp_path = os.path.join(tmp_dir, f"stacked{i}_{uuid.uuid4().hex}.h5")
            with h5py.File(tmp_path, "w") as f:
                # Store the data under 'raw_key', because the datasets below read it with that key.
                f.create_dataset(raw_key, data=stacked, compression="gzip")
            stacked_paths.append(tmp_path)

        # Update the variables for RawDataset().
        data_paths = tuple(stacked_paths)
        base_transform = torch_em.transform.get_raw_transform()
        raw_transform = ComposedTransform(base_transform, drop_mask_channel)
        # Only wrap an actual sampler; wrapping None would fail as soon as the sampler is called.
        if sampler is not None:
            sampler = ChannelSplitterSampler(sampler)
        with_channels = True
    else:
        raw_transform = torch_em.transform.get_raw_transform()
        with_channels = False
        # Without a mask the sampler cannot be called as sampler(raw, mask), so it is not used.
        sampler = None

    _, ndim = _determine_ndim(patch_shape)
    transform = torch_em.transform.get_augmentations(ndim=ndim)

    if n_samples is None:
        n_samples_per_ds = None
    else:
        # Distribute the requested number of samples evenly over the datasets.
        n_samples_per_ds = int(n_samples / len(data_paths))

    augmentations = (weak_augmentations(), weak_augmentations())

    datasets = [
        torch_em.data.RawDataset(
            path, raw_key, patch_shape, raw_transform, transform, roi=roi,
            n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim,
            with_channels=with_channels, augmentations=augmentations,
        )
        for path in data_paths
    ]
    ds = torch.utils.data.ConcatDataset(datasets)

    num_workers = 4 * batch_size
    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
                                                   num_workers=num_workers, shuffle=True)
    return loader


# TODO: use different paths for supervised and unsupervised training
# (We are currently not using this functionality directly, so this is not a high priority)
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
            The check itself is not implemented yet; the function returns after creating the loaders.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    # The unsupervised loaders draw from the same files as the supervised ones (see TODO above).
    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    # TODO check the semisup loader
    if check:
        # from torch_em.util.debug import check_loader
        # check_loader(train_loader, n_samples=4)
        # check_loader(val_loader, n_samples=4)
        return

    # Check for 2D or 3D training: a singleton z-axis means 2D.
    z, _, _ = patch_shape
    is_2d = z == 1

    if is_2d:
        model = get_2d_model(out_channels=2)
    else:
        model = get_3d_model(out_channels=2)

    # Use the learning rate requested by the caller (it was previously hard-coded to 1e-4).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # The same loss objects are used for the supervised and the unsupervised branch.
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)
def
weak_augmentations(p: float = 0.75) -> <built-in function callable>:
def weak_augmentations(p: float = 0.75) -> callable:
    """Build the weak augmentation transform for the unsupervised data loader.

    The augmentation pipeline consists of standardization, followed by a Gaussian blur
    and additive Gaussian noise, where each of the latter two is applied with probability `p`.

    Args:
        p: The probability for applying one of the augmentations.

    Returns:
        The transformation function applying the augmentation.
    """
    raw_transforms = torch_em.transform.raw
    standardize = raw_transforms.standardize

    # Each augmentation is wrapped separately, so blur and noise are drawn independently.
    random_blur = transforms.RandomApply([raw_transforms.GaussianBlur()], p=p)
    additive_noise = raw_transforms.AdditiveGaussianNoise(scale=(0, 0.15), clip_kwargs=False)
    random_noise = transforms.RandomApply([additive_noise], p=p)

    pipeline = transforms.Compose([standardize, random_blur, random_noise])
    return raw_transforms.get_raw_transform(normalizer=standardize, augmentation1=pipeline)
The weak augmentations used in the unsupervised data loader.
Arguments:
- p: The probability for applying one of the augmentations.
Returns:
The transformation function applying the augmentation.
def
drop_mask_channel(x):
class
ComposedTransform:
class
ChannelSplitterSampler:
def
get_unsupervised_loader( data_paths: Tuple[str], raw_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], sample_mask_paths: Optional[Tuple[str]] = None, sampler: Optional[<built-in function callable>] = None, exclude_top_and_bottom: bool = False) -> torch.utils.data.dataloader.DataLoader:
def get_unsupervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    sample_mask_paths: Optional[Tuple[str]] = None,
    sampler: Optional[callable] = None,
    exclude_top_and_bottom: bool = False,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for unsupervised segmentation training.

    Args:
        data_paths: The filepaths to the files containing the training data.
            These are hdf5 files, unless `sample_mask_paths` is given; in that case
            the data (and the masks) are read with `read_mrc`.
        raw_key: The key that holds the raw data inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
        sampler: Accept or reject patches based on a condition. It is called as `sampler(raw, mask)`
            and is only used if `sample_mask_paths` is given.
        exclude_top_and_bottom: Whether to exclude the five top and bottom slices to
            avoid artifacts at the border of tomograms.

    Returns:
        The PyTorch dataloader.
    """
    # Imported locally so that this function does not depend on additional module-level imports.
    import os
    import tempfile

    # We exclude the top and bottom slices where the tomogram reconstruction is bad.
    # TODO this seems unnecessary if we have a boundary mask - remove?
    roi = np.s_[5:-5, :, :] if exclude_top_and_bottom else None

    # Stack tomograms and masks and write them to temp files to use as input to RawDataset().
    if sample_mask_paths is not None:
        assert len(data_paths) == len(sample_mask_paths), \
            f"Expected equal number of data_paths and and sample_masks_paths, got {len(data_paths)} data paths and {len(sample_mask_paths)} mask paths."

        # Use the platform's temp directory instead of a hard-coded '/tmp'.
        # NOTE: The stacked files are not removed by this function.
        tmp_dir = tempfile.gettempdir()
        stacked_paths = []
        for i, (data_path, mask_path) in enumerate(zip(data_paths, sample_mask_paths)):
            raw = read_mrc(data_path)[0]
            mask = read_mrc(mask_path)[0]
            stacked = np.stack([raw, mask], axis=0)

            tmp_path = os.path.join(tmp_dir, f"stacked{i}_{uuid.uuid4().hex}.h5")
            with h5py.File(tmp_path, "w") as f:
                # Store the data under 'raw_key', because the datasets below read it with that key.
                f.create_dataset(raw_key, data=stacked, compression="gzip")
            stacked_paths.append(tmp_path)

        # Update the variables for RawDataset().
        data_paths = tuple(stacked_paths)
        base_transform = torch_em.transform.get_raw_transform()
        raw_transform = ComposedTransform(base_transform, drop_mask_channel)
        # Only wrap an actual sampler; wrapping None would fail as soon as the sampler is called.
        if sampler is not None:
            sampler = ChannelSplitterSampler(sampler)
        with_channels = True
    else:
        raw_transform = torch_em.transform.get_raw_transform()
        with_channels = False
        # Without a mask the sampler cannot be called as sampler(raw, mask), so it is not used.
        sampler = None

    _, ndim = _determine_ndim(patch_shape)
    transform = torch_em.transform.get_augmentations(ndim=ndim)

    if n_samples is None:
        n_samples_per_ds = None
    else:
        # Distribute the requested number of samples evenly over the datasets.
        n_samples_per_ds = int(n_samples / len(data_paths))

    augmentations = (weak_augmentations(), weak_augmentations())

    datasets = [
        torch_em.data.RawDataset(
            path, raw_key, patch_shape, raw_transform, transform, roi=roi,
            n_samples=n_samples_per_ds, sampler=sampler, ndim=ndim,
            with_channels=with_channels, augmentations=augmentations,
        )
        for path in data_paths
    ]
    ds = torch.utils.data.ConcatDataset(datasets)

    num_workers = 4 * batch_size
    loader = torch_em.segmentation.get_data_loader(ds, batch_size=batch_size,
                                                   num_workers=num_workers, shuffle=True)
    return loader
Get a dataloader for unsupervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- exclude_top_and_bottom: Whether to exclude the five top and bottom slices to avoid artifacts at the border of tomograms.
- sample_mask_paths: The filepaths to the corresponding sample masks for each tomogram.
- sampler: Accept or reject patches based on a condition.
Returns:
The PyTorch dataloader.
def
semisupervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: str, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False) -> None:
def semisupervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: str,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
) -> None:
    """Run semi-supervised segmentation training.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
            The check itself is not implemented yet; the function returns after creating the loaders.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val)

    # The unsupervised loaders draw from the same files as the supervised ones.
    unsupervised_train_loader = get_unsupervised_loader(train_paths, raw_key, patch_shape, batch_size,
                                                        n_samples=n_samples_train)
    unsupervised_val_loader = get_unsupervised_loader(val_paths, raw_key, patch_shape, batch_size,
                                                      n_samples=n_samples_val)

    # TODO check the semisup loader
    if check:
        # from torch_em.util.debug import check_loader
        # check_loader(train_loader, n_samples=4)
        # check_loader(val_loader, n_samples=4)
        return

    # Check for 2D or 3D training: a singleton z-axis means 2D.
    z, _, _ = patch_shape
    is_2d = z == 1

    if is_2d:
        model = get_2d_model(out_channels=2)
    else:
        model = get_3d_model(out_channels=2)

    # Use the learning rate requested by the caller (it was previously hard-coded to 1e-4).
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

    # Self training functionality.
    pseudo_labeler = self_training.DefaultPseudoLabeler(confidence_threshold=0.9)
    loss = self_training.DefaultSelfTrainingLoss()
    loss_and_metric = self_training.DefaultSelfTrainingLossAndMetric()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # The same loss objects are used for the supervised and the unsupervised branch.
    trainer = self_training.MeanTeacherTrainer(
        name=name,
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        pseudo_labeler=pseudo_labeler,
        unsupervised_loss=loss,
        unsupervised_loss_and_metric=loss_and_metric,
        supervised_train_loader=train_loader,
        unsupervised_train_loader=unsupervised_train_loader,
        supervised_val_loader=val_loader,
        unsupervised_val_loader=unsupervised_val_loader,
        supervised_loss=loss,
        supervised_loss_and_metric=loss_and_metric,
        logger=self_training.SelfTrainingTensorboardLogger,
        mixed_precision=True,
        device=device,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(n_iterations)
Run semi-supervised segmentation training.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.