synapse_net.training.supervised_training
import os
from glob import glob
from typing import Optional, Tuple

import torch
import torch_em
from sklearn.model_selection import train_test_split
from torch_em.model import AnisotropicUNet, UNet2d

from synapse_net.inference.inference import get_model_path, get_available_models


def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Tuple[Tuple[int, int, int]] = ((1, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # NOTE: the default for scale_factors is an immutable tuple of tuples so that
    # the shared default value cannot be mutated by accident.
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model


def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    model = UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return model


def _adjust_patch_shape(ndim, patch_shape):
    # If training is 2D but the patch shape is 3D, drop the leading (z) axis of the patch shape.
    if ndim == 2 and len(patch_shape) == 3:
        return patch_shape[1:]
    return patch_shape  # Return the original patch_shape for 3D data.


def _determine_ndim(patch_shape):
    # Determine 2D vs. 3D training from the patch shape:
    # 2D training is signaled either by a 2-tuple or by a singleton z-axis.
    try:
        z, y, x = patch_shape
    except ValueError:
        y, x = patch_shape
        z = 1
    is_2d = z == 1
    ndim = 2 if is_2d else 3
    return is_2d, ndim


def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Returns:
        The PyTorch dataloader.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Determine the label transform: an explicitly passed transform takes precedence;
    # otherwise a boundary transform (optionally respecting the ignore label) or
    # connected components is used.
    if label_transform is not None:
        pass  # A specific label transform was passed, do nothing.
    elif add_boundary_transform:
        if ignore_label is None:
            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
        else:
            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                add_binary_target=True, ignore_label=ignore_label
            )
    else:
        if ignore_label is not None:
            # Connected-component targets with an ignore label are not supported.
            raise NotImplementedError
        label_transform = torch_em.transform.label.connected_components

    if ndim == 2:
        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
        transform = torch_em.transform.Compose(
            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
        )
    else:
        transform = torch_em.transform.Compose(
            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
        )

    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader


def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    loss_fn: Optional[torch.nn.Module] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        loss_fn: Custom loss function. If None, will default to `torch_em.loss.DiceLoss`.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are passed, since the two
            masking mechanisms are mutually exclusive.
    """
    # The two masking mechanisms are documented as incompatible; fail loudly instead of
    # silently preferring one of them.
    if ignore_label is not None and mask_channel:
        raise ValueError("The options 'ignore_label' and 'mask_channel' are not compatible.")

    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        # Only visualize samples from the loaders, don't run training.
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        # Continue training from an existing checkpoint.
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    base_loss = loss_fn if loss_fn is not None else torch_em.loss.DiceLoss()

    if ignore_label is not None:
        # Mask the areas marked with the ignore label so that they do not
        # contribute to the gradient computation.
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # Use the last label channel as a loss mask.
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss
    else:
        # No masking required -> use the plain loss for both loss and metric.
        loss = base_loss
        metric = base_loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)


def _derive_key_from_files(files, key):
    # Nothing to derive from an empty file list; downstream checks report the mismatch.
    if not files:
        return files, key

    # Get all file extensions (general wild-cards may pick up files with multiple extensions).
    extensions = list(set(os.path.splitext(ff)[1] for ff in files))

    # If we have more than 1 file extension we just use the key that was passed,
    # as it is unclear how to derive a consistent key.
    if len(extensions) > 1:
        return files, key

    ext = extensions[0]
    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}

    # Derive the key from the extension if the key is None.
    if key is None and ext in extension_to_key:
        key = extension_to_key[ext]
    # If the key is None and can't be derived raise an error.
    elif key is None and ext not in extension_to_key:
        raise ValueError(
            f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
        )
    # If the key was passed and doesn't match the extension raise an error.
    elif key is not None and ext in extension_to_key and key != extension_to_key[ext]:
        raise ValueError(
            f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
        )
    return files, key


def _parse_input_folder(folder, pattern, key):
    # Recursively collect all files matching the pattern and derive the internal data key.
    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    return _derive_key_from_files(files, key)


def _parse_input_files(args):
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match. "
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        # No dedicated validation data: split the training data into train/val.
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key


def _parse_checkpoint(initial_model):
    # Resolve the initial model: None (random init), a local checkpoint path,
    # or the name of a pretrained model.
    if initial_model is None:
        return None
    if os.path.exists(initial_model):
        return initial_model
    model_path = get_model_path(initial_model)
    return model_path
def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument)."  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # The patch shape is required: without it training would fail later with an opaque error.
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, required=True,
                        help="The patch shape for training.")

    # Folders with training data, containing raw/image data and labels.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # Optional: choose a model for initializing the weights.
    available_models = get_available_models()
    parser.add_argument(
        "--initial_model",
        help="Choose a model checkpoint for weight initialization.\n"
             "This may either be the path to an existing model checkpoint or the name of a pretrained model.\n"
             f"The following pretrained models are available: {available_models}.\n"
             "If not given, the model will be randomly initialized."
    )

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--n_iterations", type=int, default=int(1e5), help="The maximal number of iterations to train for.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")
    args = parser.parse_args()

    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)
    checkpoint_path = _parse_checkpoint(args.initial_model)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check, n_iterations=args.n_iterations, save_root=args.save_root,
        checkpoint_path=checkpoint_path
    )
def
get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
14def get_3d_model( 15 out_channels: int, 16 in_channels: int = 1, 17 scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 18 initial_features: int = 32, 19 final_activation: str = "Sigmoid", 20) -> torch.nn.Module: 21 """Get the U-Net model for 3D segmentation tasks. 22 23 Args: 24 out_channels: The number of output channels of the network. 25 scale_factors: The downscaling factors for each level of the U-Net encoder. 26 initial_features: The number of features in the first level of the U-Net. 27 The number of features increases by a factor of two in each level. 28 final_activation: The activation applied to the last output layer. 29 30 Returns: 31 The U-Net. 32 """ 33 model = AnisotropicUNet( 34 scale_factors=scale_factors, 35 in_channels=in_channels, 36 out_channels=out_channels, 37 initial_features=initial_features, 38 gain=2, 39 final_activation=final_activation, 40 ) 41 return model
Get the U-Net model for 3D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- scale_factors: The downscaling factors for each level of the U-Net encoder.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
44def get_2d_model( 45 out_channels: int, 46 in_channels: int = 1, 47 initial_features: int = 32, 48 final_activation: str = "Sigmoid", 49) -> torch.nn.Module: 50 """Get the U-Net model for 2D segmentation tasks. 51 52 Args: 53 out_channels: The number of output channels of the network. 54 initial_features: The number of features in the first level of the U-Net. 55 The number of features increases by a factor of two in each level. 56 final_activation: The activation applied to the last output layer. 57 58 Returns: 59 The U-Net. 60 """ 61 model = UNet2d( 62 in_channels=in_channels, 63 out_channels=out_channels, 64 initial_features=initial_features, 65 gain=2, 66 depth=4, 67 final_activation=final_activation, 68 ) 69 return model
Get the U-Net model for 2D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, label_paths: Optional[Tuple[str]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
91def get_supervised_loader( 92 data_paths: Tuple[str], 93 raw_key: str, 94 label_key: str, 95 patch_shape: Tuple[int, int, int], 96 batch_size: int, 97 n_samples: Optional[int], 98 add_boundary_transform: bool = True, 99 label_dtype=torch.float32, 100 rois: Optional[Tuple[Tuple[slice]]] = None, 101 sampler: Optional[callable] = None, 102 ignore_label: Optional[int] = None, 103 label_transform: Optional[callable] = None, 104 label_paths: Optional[Tuple[str]] = None, 105 **loader_kwargs, 106) -> torch.utils.data.DataLoader: 107 """Get a dataloader for supervised segmentation training. 108 109 Args: 110 data_paths: The filepaths to the hdf5 files containing the training data. 111 raw_key: The key that holds the raw data inside of the hdf5. 112 label_key: The key that holds the labels inside of the hdf5. 113 patch_shape: The patch shape used for a training example. 114 In order to run 2d training pass a patch shape with a singleton in the z-axis, 115 e.g. 'patch_shape = [1, 512, 512]'. 116 batch_size: The batch size for training. 117 n_samples: The number of samples per epoch. By default this will be estimated 118 based on the patch_shape and size of the volumes used for training. 119 add_boundary_transform: Whether to add a boundary channel to the training data. 120 label_dtype: The datatype of the labels returned by the dataloader. 121 rois: Optional region of interests for training. 122 sampler: Optional sampler for selecting blocks for training. 123 By default a minimum instance sampler will be used. 124 ignore_label: Ignore label in the ground-truth. The areas marked by this label will be 125 ignored in the loss computation. By default this option is not used. 126 label_transform: Label transform that is applied to the segmentation to compute the targets. 127 If no label transform is passed (the default) a boundary transform is used. 128 label_paths: Optional paths containing the labels / annotations for training. 
129 If not given, the labels are expected to be contained in the `data_paths`. 130 loader_kwargs: Additional keyword arguments for the dataloader. 131 132 Returns: 133 The PyTorch dataloader. 134 """ 135 _, ndim = _determine_ndim(patch_shape) 136 if label_transform is not None: # A specific label transform was passed, do nothing. 137 pass 138 elif add_boundary_transform: 139 if ignore_label is None: 140 label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True) 141 else: 142 label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel( 143 add_binary_target=True, ignore_label=ignore_label 144 ) 145 146 else: 147 if ignore_label is not None: 148 raise NotImplementedError 149 label_transform = torch_em.transform.label.connected_components 150 151 if ndim == 2: 152 adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape) 153 transform = torch_em.transform.Compose( 154 torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2) 155 ) 156 else: 157 transform = torch_em.transform.Compose( 158 torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3) 159 ) 160 161 num_workers = loader_kwargs.pop("num_workers", 4 * batch_size) 162 shuffle = loader_kwargs.pop("shuffle", True) 163 164 if sampler is None: 165 sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4) 166 167 if label_paths is None: 168 label_paths = data_paths 169 elif len(label_paths) != len(data_paths): 170 raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}") 171 172 loader = torch_em.default_segmentation_loader( 173 data_paths, raw_key, 174 label_paths, label_key, sampler=sampler, 175 batch_size=batch_size, patch_shape=patch_shape, ndim=ndim, 176 is_seg_dataset=True, label_transform=label_transform, transform=transform, 177 num_workers=num_workers, shuffle=shuffle, n_samples=n_samples, 178 label_dtype=label_dtype, rois=rois, **loader_kwargs, 
179 ) 180 return loader
Get a dataloader for supervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- add_boundary_transform: Whether to add a boundary channel to the training data.
- label_dtype: The datatype of the labels returned by the dataloader.
- rois: Optional region of interests for training.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- label_paths: Optional paths containing the labels / annotations for training.
If not given, the labels are expected to be contained in the `data_paths`.
- loader_kwargs: Additional keyword arguments for the dataloader.
Returns:
The PyTorch dataloader.
def
supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_label_paths: Optional[Tuple[str]] = None, val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, loss_fn: Optional[torch.nn.modules.module.Module] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, checkpoint_path: Optional[str] = None, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    loss_fn: Optional[torch.nn.Module] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        loss_fn: Custom loss function. If None, will default to `torch_em.loss.DiceLoss`.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    # The two masking mechanisms are mutually exclusive (see docstring);
    # previously passing both silently applied only the ignore-label masking.
    if ignore_label is not None and mask_channel:
        raise ValueError("'ignore_label' and 'mask_channel' are not compatible, please only pass one of them.")

    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        # Only visualize samples from the loaders instead of training.
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # Either resume from a previous checkpoint or build a fresh 2D / 3D UNet,
    # depending on whether the patch shape has a singleton z-axis.
    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    base_loss = loss_fn if loss_fn is not None else torch_em.loss.DiceLoss()

    # With an ignore label or an extra mask channel the loss (and metric) must be
    # wrapped so that the masked-out areas do not contribute to the gradient.
    if ignore_label is not None:
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss
    else:
        # No masking required -> use the loss directly.
        loss = base_loss
        metric = base_loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
Run supervised segmentation training.
This function trains a UNet for predicting outputs for segmentation.
Expects instance labels and converts them to boundary targets.
This behaviour can be changed by passing custom arguments for label_transform
and/or out_channels.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- train_label_paths: Optional paths containing the label data for training. If not given, the labels are expected to be part of train_paths.
- val_label_paths: Optional paths containing the label data for validation. If not given, the labels are expected to be part of val_paths.
- train_rois: Optional region of interests for training.
- val_rois: Optional region of interests for validation.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- loss_fn: Custom loss function. If None, will default to torch_em.loss.DiceLoss.
- out_channels: The number of output channels of the UNet.
- mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with ignore_label.
- checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
- loader_kwargs: Additional keyword arguments for the dataloader.