synapse_net.training.supervised_training
import os
from glob import glob
from typing import Callable, Optional, Tuple

import torch
import torch_em
from sklearn.model_selection import train_test_split
from torch_em.model import AnisotropicUNet, UNet2d

from synapse_net.inference.inference import get_model_path, get_available_models


def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Tuple[Tuple[int, int, int]] = ((1, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)),
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
            The default first level only downscales in-plane ([1, 2, 2]), i.e. it is anisotropic.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # The default is a tuple of tuples (not a list of lists) to avoid a mutable default argument.
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model


def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    model = UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return model


def _adjust_patch_shape(ndim, patch_shape):
    # If the data is 2D and the patch_shape is 3D, drop the leading (z) axis of the patch_shape.
    if ndim == 2 and len(patch_shape) == 3:
        return patch_shape[1:]
    return patch_shape  # Return the original patch_shape otherwise.


def _determine_ndim(patch_shape):
    # Check for 2D or 3D training: a 2-tuple patch shape, or a 3-tuple with a
    # singleton z-axis (e.g. [1, 512, 512]), means 2D training.
    try:
        z, y, x = patch_shape
    except ValueError:
        y, x = patch_shape
        z = 1
    is_2d = z == 1
    ndim = 2 if is_2d else 3
    return is_2d, ndim


def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[Callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Returns:
        The PyTorch dataloader.
    """
    _, ndim = _determine_ndim(patch_shape)
    if label_transform is not None:  # A specific label transform was passed, do nothing.
        pass
    elif add_boundary_transform:
        if ignore_label is None:
            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
        else:
            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                add_binary_target=True, ignore_label=ignore_label
            )
    else:
        if ignore_label is not None:
            raise NotImplementedError(
                "Using an ignore label is only supported together with a boundary transform."
            )
        label_transform = torch_em.transform.label.connected_components

    # For 2D training the patch shape passed to the padding transform must also be 2D.
    if ndim == 2:
        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
        transform = torch_em.transform.Compose(
            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
        )
    else:
        transform = torch_em.transform.Compose(
            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
        )

    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader


def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[Callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    # Debug mode: visualize loader samples instead of training.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        # Continue training from an existing checkpoint instead of a fresh model.
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    loss, metric = None, None
    # No ignore label -> we can use the default loss of the trainer.
    if ignore_label is None and not mask_channel:
        pass
    # If we have an ignore label the loss and metric have to be modified
    # so that the ignore mask is not used in the gradient calculation.
    elif ignore_label is not None:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss
    else:
        raise ValueError("Invalid combination of 'ignore_label' and 'mask_channel'.")

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)


def _derive_key_from_files(files, key):
    # Guard against an empty file list, which would otherwise fail below with an opaque IndexError.
    if not files:
        raise ValueError("Could not derive a key: the list of input files is empty.")

    # Get all file extensions (general wild-cards may pick up files with multiple extensions).
    extensions = list(set([os.path.splitext(ff)[1] for ff in files]))

    # If we have more than 1 file extension we just use the key that was passed,
    # as it is unclear how to derive a consistent key.
    if len(extensions) > 1:
        return files, key

    ext = extensions[0]
    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}

    # Derive the key from the extension if the key is None.
    if key is None and ext in extension_to_key:
        key = extension_to_key[ext]
    # If the key is None and can't be derived raise an error.
    elif key is None and ext not in extension_to_key:
        raise ValueError(
            f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
        )
    # If the key was passed and doesn't match the extension raise an error.
    elif key is not None and ext in extension_to_key and key != extension_to_key[ext]:
        raise ValueError(
            f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
        )
    return files, key


def _parse_input_folder(folder, pattern, key):
    # Collect all files matching the pattern (recursively) and derive the internal data key.
    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    return _derive_key_from_files(files, key)


def _parse_input_files(args):
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match."
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        # No separate validation data: split the training data into train/val.
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key


def _parse_checkpoint(initial_model):
    # Resolve the initial model: None -> random init, existing path -> use it directly,
    # otherwise treat it as the name of a pretrained model and look up its path.
    if initial_model is None:
        return None
    if os.path.exists(initial_model):
        return initial_model
    model_path = get_model_path(initial_model)
    return model_path
def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument)."  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # The patch shape is required: without it supervised_training would fail later with an
    # unhelpful TypeError, so let argparse report the missing argument up-front.
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, required=True,
                        help="The patch shape for training.")

    # Folders with training data, containing raw/image data and labels.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # Optional: choose a model for initializing the weights.
    available_models = get_available_models()
    parser.add_argument(
        "--initial_model",
        help="Choose a model checkpoint for weight initialization.\n"
        "This may either be the path to an existing model checkpoint or the name of a pretrained model.\n"
        f"The following pretrained models are available: {available_models}.\n"
        "If not given, the model will be randomly initialized."
    )

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--n_iterations", type=int, default=int(1e5), help="The maximal number of iterations to train for.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")
    args = parser.parse_args()

    # Resolve the input files / internal keys and the optional initial checkpoint,
    # then launch training.
    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)
    checkpoint_path = _parse_checkpoint(args.initial_model)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check, n_iterations=args.n_iterations, save_root=args.save_root,
        checkpoint_path=checkpoint_path
    )
def
get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
14def get_3d_model( 15 out_channels: int, 16 in_channels: int = 1, 17 scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], 18 initial_features: int = 32, 19 final_activation: str = "Sigmoid", 20) -> torch.nn.Module: 21 """Get the U-Net model for 3D segmentation tasks. 22 23 Args: 24 out_channels: The number of output channels of the network. 25 scale_factors: The downscaling factors for each level of the U-Net encoder. 26 initial_features: The number of features in the first level of the U-Net. 27 The number of features increases by a factor of two in each level. 28 final_activation: The activation applied to the last output layer. 29 30 Returns: 31 The U-Net. 32 """ 33 model = AnisotropicUNet( 34 scale_factors=scale_factors, 35 in_channels=in_channels, 36 out_channels=out_channels, 37 initial_features=initial_features, 38 gain=2, 39 final_activation=final_activation, 40 ) 41 return model
Get the U-Net model for 3D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- scale_factors: The downscaling factors for each level of the U-Net encoder.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
44def get_2d_model( 45 out_channels: int, 46 in_channels: int = 1, 47 initial_features: int = 32, 48 final_activation: str = "Sigmoid", 49) -> torch.nn.Module: 50 """Get the U-Net model for 2D segmentation tasks. 51 52 Args: 53 out_channels: The number of output channels of the network. 54 initial_features: The number of features in the first level of the U-Net. 55 The number of features increases by a factor of two in each level. 56 final_activation: The activation applied to the last output layer. 57 58 Returns: 59 The U-Net. 60 """ 61 model = UNet2d( 62 in_channels=in_channels, 63 out_channels=out_channels, 64 initial_features=initial_features, 65 gain=2, 66 depth=4, 67 final_activation=final_activation, 68 ) 69 return model
Get the U-Net model for 2D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, label_paths: Optional[Tuple[str]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
91def get_supervised_loader( 92 data_paths: Tuple[str], 93 raw_key: str, 94 label_key: str, 95 patch_shape: Tuple[int, int, int], 96 batch_size: int, 97 n_samples: Optional[int], 98 add_boundary_transform: bool = True, 99 label_dtype=torch.float32, 100 rois: Optional[Tuple[Tuple[slice]]] = None, 101 sampler: Optional[callable] = None, 102 ignore_label: Optional[int] = None, 103 label_transform: Optional[callable] = None, 104 label_paths: Optional[Tuple[str]] = None, 105 **loader_kwargs, 106) -> torch.utils.data.DataLoader: 107 """Get a dataloader for supervised segmentation training. 108 109 Args: 110 data_paths: The filepaths to the hdf5 files containing the training data. 111 raw_key: The key that holds the raw data inside of the hdf5. 112 label_key: The key that holds the labels inside of the hdf5. 113 patch_shape: The patch shape used for a training example. 114 In order to run 2d training pass a patch shape with a singleton in the z-axis, 115 e.g. 'patch_shape = [1, 512, 512]'. 116 batch_size: The batch size for training. 117 n_samples: The number of samples per epoch. By default this will be estimated 118 based on the patch_shape and size of the volumes used for training. 119 add_boundary_transform: Whether to add a boundary channel to the training data. 120 label_dtype: The datatype of the labels returned by the dataloader. 121 rois: Optional region of interests for training. 122 sampler: Optional sampler for selecting blocks for training. 123 By default a minimum instance sampler will be used. 124 ignore_label: Ignore label in the ground-truth. The areas marked by this label will be 125 ignored in the loss computation. By default this option is not used. 126 label_transform: Label transform that is applied to the segmentation to compute the targets. 127 If no label transform is passed (the default) a boundary transform is used. 128 label_paths: Optional paths containing the labels / annotations for training. 
129 If not given, the labels are expected to be contained in the `data_paths`. 130 loader_kwargs: Additional keyword arguments for the dataloader. 131 132 Returns: 133 The PyTorch dataloader. 134 """ 135 _, ndim = _determine_ndim(patch_shape) 136 if label_transform is not None: # A specific label transform was passed, do nothing. 137 pass 138 elif add_boundary_transform: 139 if ignore_label is None: 140 label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True) 141 else: 142 label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel( 143 add_binary_target=True, ignore_label=ignore_label 144 ) 145 146 else: 147 if ignore_label is not None: 148 raise NotImplementedError 149 label_transform = torch_em.transform.label.connected_components 150 151 if ndim == 2: 152 adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape) 153 transform = torch_em.transform.Compose( 154 torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2) 155 ) 156 else: 157 transform = torch_em.transform.Compose( 158 torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3) 159 ) 160 161 num_workers = loader_kwargs.pop("num_workers", 4 * batch_size) 162 shuffle = loader_kwargs.pop("shuffle", True) 163 164 if sampler is None: 165 sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4) 166 167 if label_paths is None: 168 label_paths = data_paths 169 elif len(label_paths) != len(data_paths): 170 raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}") 171 172 loader = torch_em.default_segmentation_loader( 173 data_paths, raw_key, 174 label_paths, label_key, sampler=sampler, 175 batch_size=batch_size, patch_shape=patch_shape, ndim=ndim, 176 is_seg_dataset=True, label_transform=label_transform, transform=transform, 177 num_workers=num_workers, shuffle=shuffle, n_samples=n_samples, 178 label_dtype=label_dtype, rois=rois, **loader_kwargs, 
179 ) 180 return loader
Get a dataloader for supervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- add_boundary_transform: Whether to add a boundary channel to the training data.
- label_dtype: The datatype of the labels returned by the dataloader.
- rois: Optional region of interests for training.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- label_paths: Optional paths containing the labels / annotations for training.
If not given, the labels are expected to be contained in the `data_paths`.
- loader_kwargs: Additional keyword arguments for the dataloader.
Returns:
The PyTorch dataloader.
def
supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_label_paths: Optional[Tuple[str]] = None, val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, checkpoint_path: Optional[str] = None, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are passed, since they are incompatible.
    """
    # Both options modify the loss masking and cannot be combined. Previously passing both
    # silently used only the ignore_label path; fail loudly instead.
    if ignore_label is not None and mask_channel:
        raise ValueError(
            "The arguments 'ignore_label' and 'mask_channel' are not compatible. Please pass only one of them."
        )

    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    # Debug mode: visualize samples from the loaders instead of training.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        # Continue training from an existing checkpoint instead of initializing a new model.
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # With neither masking option set, loss and metric stay None and the trainer's defaults are used.
    loss, metric = None, None
    if ignore_label is not None:
        # Mask out the ignore label so that it does not contribute to the gradient computation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # Use the last label channel as a loss mask and remove it before the loss computation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
Run supervised segmentation training.
This function trains a UNet for predicting outputs for segmentation.
Expects instance labels and converts them to boundary targets.
This behaviour can be changed by passing custom arguments for label_transform
and/or out_channels.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- train_label_paths: Optional paths containing the label data for training.
If not given, the labels are expected to be part of train_paths.
- val_label_paths: Optional paths containing the label data for validation.
If not given, the labels are expected to be part of val_paths.
- train_rois: Optional region of interests for training.
- val_rois: Optional region of interests for validation.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- out_channels: The number of output channels of the UNet.
- mask_channel: Whether the last channels in the labels should be used for masking the loss.
This can be used to implement more complex masking operations and is not compatible with ignore_label.
- checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
- loader_kwargs: Additional keyword arguments for the dataloader.