synapse_net.training.supervised_training
import os
from glob import glob
from typing import Callable, Optional, Tuple, Union

import torch
import torch_em
from sklearn.model_selection import train_test_split
from torch_em.model import AnisotropicUNet, UNet2d

from synapse_net.inference.inference import get_model_path, get_available_models


def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Optional[Tuple[Tuple[int, int, int], ...]] = None,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
            If None, `[[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]` is used.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    if scale_factors is None:
        # Create the default per call, so that no list object is shared between
        # the function signature and the models built from it.
        scale_factors = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model


def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    return UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )


def _adjust_patch_shape(data_shape, patch_shape):
    """Drop the leading axis of a three-entry patch shape for 2D data.

    Args:
        data_shape: The dimensionality of the data. Only the value 2 triggers the adjustment.
        patch_shape: The patch shape.

    Returns:
        `patch_shape[1:]` if `data_shape` is 2 and `patch_shape` has three entries,
        otherwise `patch_shape` unchanged.
    """
    if data_shape == 2 and len(patch_shape) == 3:
        return patch_shape[1:]
    return patch_shape


def _determine_ndim(patch_shape):
    """Determine from the patch shape whether training is 2D or 3D.

    A patch shape with two entries, or with three entries and a singleton
    first (z) entry, is treated as 2D. Everything else is treated as 3D.

    Returns:
        The tuple `(is_2d, ndim)`, where `ndim` is either 2 or 3.
    """
    try:
        z, y, x = patch_shape
    except ValueError:
        # The unpacking into three values failed, so we expect a 2D patch shape.
        y, x = patch_shape
        z = 1
    is_2d = z == 1
    ndim = 2 if is_2d else 3
    return is_2d, ndim


def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Union[Callable, bool]] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[Callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler to accept or reject patches for training.
            By default a minimum instance sampler will be used, pass `False` to disable.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.
            The entries 'num_workers' (default: `4 * batch_size`) and 'shuffle' (default: True)
            are extracted from it and passed on explicitly.

    Returns:
        The PyTorch dataloader.

    Raises:
        NotImplementedError: If `add_boundary_transform` is False, no `label_transform` is given
            and an `ignore_label` is passed.
        ValueError: If `label_paths` is given and its length differs from `data_paths`.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Only choose a label transform if the caller did not pass a specific one.
    if label_transform is None:
        if add_boundary_transform:
            if ignore_label is None:
                label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
            else:
                label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                    add_binary_target=True, ignore_label=ignore_label
                )
        elif ignore_label is not None:
            raise NotImplementedError
        else:
            label_transform = torch_em.transform.label.connected_components

    # For 2D training the padding works on the patch shape without the singleton z-axis.
    pad_shape = _adjust_patch_shape(ndim, patch_shape) if ndim == 2 else patch_shape
    transform = torch_em.transform.Compose(
        torch_em.transform.PadIfNecessary(pad_shape), torch_em.transform.get_augmentations(ndim)
    )

    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    # None selects the default sampler, False explicitly disables sampling.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
    elif sampler is False:
        sampler = None

    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader


def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Union[Callable, bool]] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[Callable] = None,
    loss_fn: Optional[torch.nn.Module] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    save_every_kth_epoch: Optional[int] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting patches for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        loss_fn: Custom loss function. If None, will default to `torch_em.loss.DiceLoss`.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
            If both are passed, `ignore_label` takes precedence and `mask_channel` is not applied.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
            The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    base_loss = loss_fn if loss_fn is not None else torch_em.loss.DiceLoss()

    if ignore_label is not None:
        # If we have an ignore label the loss and metric have to be modified
        # so that the ignore mask is not used in the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
    else:
        # No ignore label and no mask channel -> we can use the loss directly.
        loss = base_loss
    # The metric is always the same callable as the loss.
    metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations, save_every_kth_epoch=save_every_kth_epoch)


def _derive_key_from_files(files, key):
    """Derive or validate the internal data key based on the file extension.

    Args:
        files: The list of file paths.
        key: The key passed by the user, or None if it should be derived.

    Returns:
        The tuple `(files, key)`, with the passed or derived key.

    Raises:
        ValueError: If `files` is empty, if the key is None and cannot be derived from
            the extension, or if the passed key contradicts the key expected for the extension.
    """
    if not files:
        raise ValueError(
            "No input files were found, so the data key cannot be derived. "
            "Please check the input folder and the file pattern."
        )

    # Get all file extensions (general wild-cards may pick up files with multiple extensions).
    extensions = list({os.path.splitext(ff)[1] for ff in files})

    # If we have more than 1 file extension we just use the key that was passed,
    # as it is unclear how to derive a consistent key.
    if len(extensions) > 1:
        return files, key

    ext = extensions[0]
    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}

    if key is None:
        # If the key is None and can't be derived raise an error.
        if ext not in extension_to_key:
            raise ValueError(
                f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
            )
        # Derive the key from the extension.
        key = extension_to_key[ext]
    # If the key was passed and doesn't match the extension raise an error.
    elif ext in extension_to_key and key != extension_to_key[ext]:
        raise ValueError(
            f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
        )
    return files, key


def _parse_input_folder(folder, pattern, key):
    """Recursively collect the sorted files matching `pattern` in `folder` and derive the data key."""
    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    return _derive_key_from_files(files, key)


def _parse_input_files(args):
    """Collect the train and validation image / label paths and data keys from the CLI arguments.

    If no validation folder is given, the training data is split into a train and
    validation set according to `args.val_fraction`, using a fixed random state.

    Returns:
        The tuple `(train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key)`.

    Raises:
        ValueError: If the number of training images and labels differs, or if only one of
            `val_folder` and `val_label_folder` was passed.
    """
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match. "
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key


def _parse_checkpoint(initial_model):
    """Resolve the initial model to a checkpoint path.

    Returns None if no initial model is given, the input itself if it is an existing
    path, and otherwise the result of `get_model_path` for the given model name.
    """
    if initial_model is None:
        return None
    if os.path.exists(initial_model):
        return initial_model
    model_path = get_model_path(initial_model)
    return model_path


def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument). "  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # The patch shape has no sensible default and training cannot start without it,
    # so it is required instead of failing later when unpacking None.
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, required=True,
                        help="The patch shape for training.")

    # Folders with training data, containing raw/image data and labels.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # Optional: choose a model for initializing the weights.
    available_models = get_available_models()
    parser.add_argument(
        "--initial_model",
        help="Choose a model checkpoint for weight initialization.\n"
        "This may either be the path to an existing model checkpoint or the name of a pretrained model.\n"
        f"The following pretrained models are available: {available_models}.\n"
        "If not given, the model will be randomly initialized."
    )

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--n_iterations", type=int, default=int(1e5), help="The maximal number of iterations to train for.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")
    args = parser.parse_args()

    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)
    checkpoint_path = _parse_checkpoint(args.initial_model)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check, n_iterations=args.n_iterations, save_root=args.save_root,
        checkpoint_path=checkpoint_path
    )
def
get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Optional[Tuple[Tuple[int, int, int], ...]] = None,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
            If None, `[[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]` is used.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    if scale_factors is None:
        # Create the default per call, so that no list object is shared between
        # the function signature and the models built from it.
        scale_factors = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model
Get the U-Net model for 3D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- scale_factors: The downscaling factors for each level of the U-Net encoder.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Build the U-Net used for 2D segmentation tasks.

    The network is created with a depth of 4 and a feature gain of 2.

    Args:
        out_channels: Number of channels predicted by the network.
        in_channels: Number of channels of the network input.
        initial_features: Number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: Activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    return UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
Get the U-Net model for 2D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Union[<built-in function callable>, bool, NoneType] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, label_paths: Optional[Tuple[str]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Union[callable, bool]] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Create the dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler to accept or reject patches for training.
            By default a minimum instance sampler will be used, pass `False` to disable.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.
            The entries 'num_workers' (default: `4 * batch_size`) and 'shuffle' (default: True)
            are extracted from it and passed on explicitly.

    Returns:
        The PyTorch dataloader.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Only choose a label transform if the caller did not pass a specific one.
    if label_transform is None:
        if add_boundary_transform:
            if ignore_label is None:
                label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
            else:
                label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                    add_binary_target=True, ignore_label=ignore_label
                )
        elif ignore_label is not None:
            raise NotImplementedError
        else:
            label_transform = torch_em.transform.label.connected_components

    # For 2D training the padding works on the patch shape without the singleton z-axis.
    pad_shape = _adjust_patch_shape(ndim, patch_shape) if ndim == 2 else patch_shape
    transform = torch_em.transform.Compose(
        torch_em.transform.PadIfNecessary(pad_shape), torch_em.transform.get_augmentations(ndim)
    )

    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    # None selects the default sampler, False explicitly disables sampling.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
    elif sampler is False:
        sampler = None

    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    return torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key,
        sampler=sampler,
        batch_size=batch_size,
        patch_shape=patch_shape,
        ndim=ndim,
        is_seg_dataset=True,
        label_transform=label_transform,
        transform=transform,
        num_workers=num_workers,
        shuffle=shuffle,
        n_samples=n_samples,
        label_dtype=label_dtype,
        rois=rois,
        **loader_kwargs,
    )
Get a dataloader for supervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- add_boundary_transform: Whether to add a boundary channel to the training data.
- label_dtype: The datatype of the labels returned by the dataloader.
- rois: Optional region of interests for training.
- sampler: Optional sampler to accept or reject patches for training. By default a minimum instance sampler will be used, pass `False` to disable.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- label_paths: Optional paths containing the labels / annotations for training. If not given, the labels are expected to be contained in the `data_paths`.
- loader_kwargs: Additional keyword arguments for the dataloader.
Returns:
The PyTorch dataloader.
def
supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_label_paths: Optional[Tuple[str]] = None, val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Union[<built-in function callable>, bool, NoneType] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, loss_fn: Optional[torch.nn.modules.module.Module] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, checkpoint_path: Optional[str] = None, save_every_kth_epoch: Optional[int] = None, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Union[callable, bool]] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    loss_fn: Optional[torch.nn.Module] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    save_every_kth_epoch: Optional[int] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional regions of interest for training.
        val_rois: Optional regions of interest for validation.
        sampler: Optional sampler for selecting patches for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        loss_fn: Custom loss function. If None, will default to `torch_em.loss.DiceLoss`.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible
            with `ignore_label`. If both are passed, a warning is issued and only `ignore_label`
            is used for masking the loss.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file.
            The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    # The two masking mechanisms are mutually exclusive. Previously this combination was
    # accepted silently and 'mask_channel' was dropped; keep that precedence (so existing
    # callers do not break) but make it visible.
    if ignore_label is not None and mask_channel:
        import warnings
        warnings.warn(
            "'mask_channel' is not compatible with 'ignore_label'. "
            "Only 'ignore_label' will be used for masking the loss.",
            stacklevel=2,
        )

    train_loader = get_supervised_loader(
        train_paths, raw_key, label_key, patch_shape, batch_size,
        n_samples=n_samples_train, rois=train_rois, sampler=sampler,
        ignore_label=ignore_label, label_transform=label_transform,
        label_paths=train_label_paths, **loader_kwargs,
    )
    val_loader = get_supervised_loader(
        val_paths, raw_key, label_key, patch_shape, batch_size,
        n_samples=n_samples_val, rois=val_rois, sampler=sampler,
        ignore_label=ignore_label, label_transform=label_transform,
        label_paths=val_label_paths, **loader_kwargs,
    )

    # Debug mode: only inspect a few samples from each loader, do not train.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # Either continue from an existing checkpoint or build a fresh 2D / 3D U-Net,
    # depending on the dimensionality implied by the patch shape.
    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    base_loss = loss_fn if loss_fn is not None else torch_em.loss.DiceLoss()

    if ignore_label is not None:
        # With an ignore label the loss has to be wrapped so that the ignored
        # areas do not contribute to the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            ),
        )
    elif mask_channel:
        # The mask is taken from the labels and removed before the loss is computed.
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply",
            ),
        )
    else:
        # No masking requested -> the plain loss can be used.
        loss = base_loss

    # The same (possibly wrapped) loss is used as the validation metric.
    metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations, save_every_kth_epoch=save_every_kth_epoch)
Run supervised segmentation training.
This function trains a UNet for predicting outputs for segmentation.
Expects instance labels and converts them to boundary targets.
This behaviour can be changed by passing custom arguments for label_transform
and/or out_channels.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- train_label_paths: Optional paths containing the label data for training. If not given, the labels are expected to be part of `train_paths`.
- val_label_paths: Optional paths containing the label data for validation. If not given, the labels are expected to be part of `val_paths`.
- train_rois: Optional regions of interest for training.
- val_rois: Optional regions of interest for validation.
- sampler: Optional sampler for selecting patches for training. By default a minimum instance sampler will be used.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- loss_fn: Custom loss function. If None, will default to `torch_em.loss.DiceLoss`.
- out_channels: The number of output channels of the UNet.
- mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
- checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
- save_every_kth_epoch: Save checkpoints after every kth epoch in a separate file. The corresponding checkpoints will be saved with the naming scheme 'epoch-{epoch}.pt'.
- loader_kwargs: Additional keyword arguments for the dataloader.