synapse_net.training.supervised_training
import os
from glob import glob
from typing import Optional, Tuple

import torch
import torch_em
from sklearn.model_selection import train_test_split
from torch_em.model import AnisotropicUNet, UNet2d


def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # NOTE: the default `scale_factors` list is shared between calls; it is handed to the
    # model unchanged, so it must never be modified in place.
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model


def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    The U-Net has a fixed depth of 4 levels.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    model = UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return model


def _adjust_patch_shape(ndim, patch_shape):
    """Drop the leading (singleton z) axis of a 3D patch shape when training in 2D.

    Args:
        ndim: The number of spatial dimensions used for training (2 or 3).
        patch_shape: The patch shape as passed by the user.

    Returns:
        The 2D patch shape if `ndim == 2` and `patch_shape` has three entries,
        otherwise `patch_shape` unchanged.
    """
    if ndim == 2 and len(patch_shape) == 3:
        return patch_shape[1:]
    return patch_shape


def _determine_ndim(patch_shape):
    """Determine whether training runs in 2D or 3D from the patch shape.

    A patch shape with two entries, or with three entries and a singleton first axis,
    results in 2D training; everything else with three entries results in 3D training.

    Returns:
        A tuple `(is_2d, ndim)`.
    """
    try:
        z, y, x = patch_shape
    except ValueError:
        y, x = patch_shape
        z = 1
    is_2d = z == 1
    ndim = 2 if is_2d else 3
    return is_2d, ndim


def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler (at least 4 instances) will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.
            `num_workers` defaults to `4 * batch_size` and `shuffle` defaults to `True`.

    Returns:
        The PyTorch dataloader.

    Raises:
        NotImplementedError: If `ignore_label` is set, `add_boundary_transform` is False
            and no custom `label_transform` is passed.
        ValueError: If `label_paths` and `data_paths` have different lengths.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Resolve the label transform; an explicitly passed transform always takes precedence.
    if label_transform is None:
        if add_boundary_transform:
            if ignore_label is None:
                label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
            else:
                label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                    add_binary_target=True, ignore_label=ignore_label
                )
        elif ignore_label is not None:
            raise NotImplementedError(
                "An ignore_label is only supported together with the boundary transform "
                "(add_boundary_transform=True) or a custom label_transform."
            )
        else:
            label_transform = torch_em.transform.label.connected_components

    # For 2D training with a 3D patch shape the singleton z-axis is dropped for padding.
    pad_shape = _adjust_patch_shape(ndim, patch_shape)
    transform = torch_em.transform.Compose(
        torch_em.transform.PadIfNecessary(pad_shape), torch_em.transform.get_augmentations(ndim)
    )

    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader


def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are set.
    """
    # Fail fast, before any data is opened: both options would need a different loss wrapper.
    if ignore_label is not None and mask_channel:
        raise ValueError("The options 'ignore_label' and 'mask_channel' cannot be combined.")

    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # Without an ignore label or mask channel, loss and metric stay None so that the
    # trainer's defaults are used.
    loss, metric = None, None
    if ignore_label is not None:
        # Mask out the ignore label so that it does not contribute to the gradient or the metric.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # Use the extra mask channel of the labels to mask the loss (see `mask_channel` above).
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)


def _derive_key_from_files(files, key):
    """Derive the internal data key for `files` from their (common) file extension.

    Args:
        files: The list of file paths.
        key: The key passed by the user, or None to derive it from the extension.

    Returns:
        The unchanged `files` and the (possibly derived) key.

    Raises:
        ValueError: If the key cannot be derived, or if the passed key does not match
            the key expected for the file extension.
    """
    # Get all file extensions (general wild-cards may pick up files with multiple extensions).
    extensions = sorted({os.path.splitext(ff)[1] for ff in files})

    # If we have more than 1 file extension we just use the key that was passed,
    # as it is unclear how to derive a consistent key.
    if len(extensions) > 1:
        return files, key

    ext = extensions[0]
    # Tif files are read without a key; mrc / rec files store their data under "data".
    extension_to_key = {".tif": None, ".tiff": None, ".mrc": "data", ".rec": "data"}

    # Derive the key from the extension if the key is None.
    if key is None and ext in extension_to_key:
        key = extension_to_key[ext]
    # If the key is None and can't be derived raise an error.
    elif key is None:
        raise ValueError(
            f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
        )
    # If the key was passed and doesn't match the extension raise an error.
    elif ext in extension_to_key and key != extension_to_key[ext]:
        raise ValueError(
            f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
        )
    return files, key


def _parse_input_folder(folder, pattern, key):
    """Recursively collect the files in `folder` matching `pattern` and derive their key.

    Raises:
        ValueError: If no file matches the pattern.
    """
    # NOTE(review): with the default pattern "*" the glob also matches sub-directories.
    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    if not files:
        raise ValueError(f"Could not find any files matching the pattern '{pattern}' in {folder}.")
    return _derive_key_from_files(files, key)


def _parse_input_files(args):
    """Collect the training and validation image / label paths from the parsed CLI arguments.

    If no validation folder is given the training data is split into train / val
    according to `args.val_fraction`.

    Returns:
        The train image paths, train label paths, val image paths, val label paths,
        the raw key and the label key.
    """
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match. "
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        # _parse_input_folder returns (paths, key); the keys were already determined for the training data.
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)
        if len(val_image_paths) != len(val_label_paths):
            raise ValueError(
                f"The image and label paths parsed from {args.val_folder} and {args.val_label_folder} don't match. "
                f"The image folder contains {len(val_image_paths)}, the label folder contains {len(val_label_paths)}."
            )

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key


# TODO enable initialization with a pre-trained model.
def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument).\n"  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # The patch shape is mandatory: training cannot determine 2D vs. 3D without it.
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, required=True,
                        help="The patch shape for training, e.g. '32 192 192'. "
                        "Pass a singleton first axis (e.g. '1 512 512') for 2D training.")

    # Folders with training data, containing raw/image data and labels.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation.")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    args = parser.parse_args()

    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check,
    )
def
get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # NOTE: the default `scale_factors` list is shared between calls; it is handed to the
    # model unchanged, so it must never be modified in place.
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model
Get the U-Net model for 3D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- in_channels: The number of input channels of the network.
- scale_factors: The downscaling factors for each level of the U-Net encoder.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    The U-Net has a fixed depth of 4 levels.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    model = UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return model
Get the U-Net model for 2D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- in_channels: The number of input channels of the network.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, label_paths: Optional[Tuple[str]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler (at least 4 instances) will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.
            `num_workers` defaults to `4 * batch_size` and `shuffle` defaults to `True`.

    Returns:
        The PyTorch dataloader.

    Raises:
        NotImplementedError: If `ignore_label` is set, `add_boundary_transform` is False
            and no custom `label_transform` is passed.
        ValueError: If `label_paths` and `data_paths` have different lengths.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Resolve the label transform; an explicitly passed transform always takes precedence.
    if label_transform is None:
        if add_boundary_transform:
            if ignore_label is None:
                label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
            else:
                label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                    add_binary_target=True, ignore_label=ignore_label
                )
        elif ignore_label is not None:
            raise NotImplementedError(
                "An ignore_label is only supported together with the boundary transform "
                "(add_boundary_transform=True) or a custom label_transform."
            )
        else:
            label_transform = torch_em.transform.label.connected_components

    # For 2D training with a 3D patch shape the singleton z-axis is dropped for padding.
    pad_shape = _adjust_patch_shape(ndim, patch_shape)
    transform = torch_em.transform.Compose(
        torch_em.transform.PadIfNecessary(pad_shape), torch_em.transform.get_augmentations(ndim)
    )

    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader
Get a dataloader for supervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- add_boundary_transform: Whether to add a boundary channel to the training data.
- label_dtype: The datatype of the labels returned by the dataloader.
- rois: Optional region of interests for training.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- label_paths: Optional paths containing the labels / annotations for training. If not given, the labels are expected to be contained in the `data_paths`.
- loader_kwargs: Additional keyword arguments for the dataloader.
Returns:
The PyTorch dataloader.
def
supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_label_paths: Optional[Tuple[str]] = None, val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are set.
    """
    # Fail fast, before any data is opened: both options would need a different loss wrapper.
    if ignore_label is not None and mask_channel:
        raise ValueError("The options 'ignore_label' and 'mask_channel' cannot be combined.")

    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # Without an ignore label or mask channel, loss and metric stay None so that the
    # trainer's defaults are used.
    loss, metric = None, None
    if ignore_label is not None:
        # Mask out the ignore label so that it does not contribute to the gradient or the metric.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # Use the extra mask channel of the labels to mask the loss (see `mask_channel` above).
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
Run supervised segmentation training.
This function trains a UNet for predicting outputs for segmentation.
Expects instance labels and converts them to boundary targets.
This behaviour can be changed by passing custom arguments for `label_transform` and/or `out_channels`.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- train_label_paths: Optional paths containing the label data for training. If not given, the labels are expected to be part of `train_paths`.
- val_label_paths: Optional paths containing the label data for validation.
If not given, the labels are expected to be part of `val_paths`.
- train_rois: Optional regions of interest for training.
- val_rois: Optional regions of interest for validation.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- out_channels: The number of output channels of the UNet.
- mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
- loader_kwargs: Additional keyword arguments for the dataloader.