synapse_net.training.supervised_training
import os
from glob import glob
from typing import Optional, Tuple

import torch
import torch_em
from sklearn.model_selection import train_test_split
from torch_em.model import AnisotropicUNet, UNet2d


def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Optional[Tuple[Tuple[int, int, int]]] = None,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
            If not given, `[[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]` is used.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # Build the default per call, so that no mutable object is shared between calls.
    if scale_factors is None:
        scale_factors = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model


def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    model = UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return model


def _adjust_patch_shape(ndim, patch_shape):
    # For 2D training with a 3D patch shape, drop the leading (z) entry of the patch shape.
    if ndim == 2 and len(patch_shape) == 3:
        return patch_shape[1:]
    # Otherwise the patch shape is used unchanged.
    return patch_shape


def _determine_ndim(patch_shape):
    # Check for 2D or 3D training: a two-element patch shape, or a three-element
    # patch shape with a singleton z-axis, selects 2D training.
    try:
        z, y, x = patch_shape
    except ValueError:
        y, x = patch_shape
        z = 1
    is_2d = z == 1
    ndim = 2 if is_2d else 3
    return is_2d, ndim


def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int] = None,
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
            Only used if no `label_transform` is passed.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.
            The entries 'num_workers' (default: 4 * batch_size) and 'shuffle' (default: True)
            are extracted and passed on explicitly.

    Returns:
        The PyTorch dataloader.

    Raises:
        NotImplementedError: If `ignore_label` is combined with `add_boundary_transform=False`
            and no custom `label_transform`.
        ValueError: If the number of `label_paths` does not match the number of `data_paths`.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Only choose a label transform if the caller did not pass a specific one.
    if label_transform is None:
        if add_boundary_transform and ignore_label is None:
            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
        elif add_boundary_transform:
            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                add_binary_target=True, ignore_label=ignore_label
            )
        elif ignore_label is not None:
            raise NotImplementedError(
                "Using an ignore_label is only supported together with the boundary transform."
            )
        else:
            label_transform = torch_em.transform.label.connected_components

    # The padding shape must match the dimensionality of the samples, so a
    # singleton z-axis is removed from the patch shape for 2D training.
    transform = torch_em.transform.Compose(
        torch_em.transform.PadIfNecessary(_adjust_patch_shape(ndim, patch_shape)),
        torch_em.transform.get_augmentations(ndim),
    )

    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader


def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
            If both are passed, `ignore_label` takes precedence.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    if checkpoint_path:
        # Continue from an existing checkpoint; building a fresh network first would be wasted work.
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    else:
        is_2d, _ = _determine_ndim(patch_shape)
        build_model = get_2d_model if is_2d else get_3d_model
        model = build_model(out_channels=out_channels, in_channels=in_channels)

    # Without ignore label or mask channel, loss and metric stay None and the trainer defaults are used.
    loss, metric = None, None
    if ignore_label is not None:
        # If we have an ignore label the loss and metric have to be modified
        # so that the ignore mask is not used in the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)


def _derive_key_from_files(files, key):
    # Without any files there is no extension to derive the key from.
    if not files:
        raise ValueError("Cannot derive the data key from an empty list of files.")

    # Get all file extensions (general wild-cards may pick up files with multiple extensions).
    extensions = {os.path.splitext(ff)[1] for ff in files}

    # If we have more than 1 file extension we just use the key that was passed,
    # as it is unclear how to derive a consistent key.
    if len(extensions) > 1:
        return files, key

    ext = next(iter(extensions))
    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}

    if ext in extension_to_key:
        expected_key = extension_to_key[ext]
        # Derive the key from the extension if the key is None.
        if key is None:
            key = expected_key
        # If the key was passed and doesn't match the extension raise an error.
        elif key != expected_key:
            raise ValueError(
                f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
            )
    # If the key is None and can't be derived raise an error.
    elif key is None:
        raise ValueError(
            f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
        )
    return files, key


def _parse_input_folder(folder, pattern, key):
    # Collect all files matching the pattern, searching the folder recursively.
    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    if not files:
        raise ValueError(f"Could not find any files matching the pattern '{pattern}' in the folder {folder}.")
    return _derive_key_from_files(files, key)


def _parse_input_files(args):
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match. "
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        # No validation data was given, so a fraction of the training data is held out for validation.
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        # The folder parser returns (files, key); only the files are needed here,
        # the keys were already determined from the training data.
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)
        if len(val_image_paths) != len(val_label_paths):
            raise ValueError(
                f"The image and label paths parsed from {args.val_folder} and {args.val_label_folder} don't match. "
                f"The image folder contains {len(val_image_paths)}, the label folder contains {len(val_label_paths)}."
            )

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key


# TODO enable initialization with a pre-trained model.
def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument).\n"  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # The patch shape is needed to set up the loaders and the model, so it must always be given.
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, required=True,
                        help="The patch shape for training.")

    # Folders with training data, containing raw/image data and labels.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    args = parser.parse_args()

    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check,
    )
def
get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Optional[Tuple[Tuple[int, int, int]]] = None,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
            If not given, `[[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]` is used.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # Build the default per call, so that no mutable object is shared between calls.
    if scale_factors is None:
        scale_factors = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model
Get the U-Net model for 3D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- scale_factors: The downscaling factors for each level of the U-Net encoder.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Build the U-Net used for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # Depth and feature gain are fixed; everything else is configured by the caller.
    unet_settings = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "initial_features": initial_features,
        "gain": 2,
        "depth": 4,
        "final_activation": final_activation,
    }
    return UNet2d(**unet_settings)
Get the U-Net model for 2D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, label_paths: Optional[Tuple[str]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int] = None,
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
            Only used if no `label_transform` is passed.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.
            The entries 'num_workers' (default: 4 * batch_size) and 'shuffle' (default: True)
            are extracted and passed on explicitly.

    Returns:
        The PyTorch dataloader.

    Raises:
        NotImplementedError: If `ignore_label` is combined with `add_boundary_transform=False`
            and no custom `label_transform`.
        ValueError: If the number of `label_paths` does not match the number of `data_paths`.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Only choose a label transform if the caller did not pass a specific one.
    if label_transform is None:
        if add_boundary_transform and ignore_label is None:
            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
        elif add_boundary_transform:
            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                add_binary_target=True, ignore_label=ignore_label
            )
        elif ignore_label is not None:
            raise NotImplementedError(
                "Using an ignore_label is only supported together with the boundary transform."
            )
        else:
            label_transform = torch_em.transform.label.connected_components

    # The padding shape must match the dimensionality of the samples, so a
    # singleton z-axis is removed from the patch shape for 2D training.
    transform = torch_em.transform.Compose(
        torch_em.transform.PadIfNecessary(_adjust_patch_shape(ndim, patch_shape)),
        torch_em.transform.get_augmentations(ndim),
    )

    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader
Get a dataloader for supervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- add_boundary_transform: Whether to add a boundary channel to the training data.
- label_dtype: The datatype of the labels returned by the dataloader.
- rois: Optional region of interests for training.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- label_paths: Optional paths containing the labels / annotations for training. If not given, the labels are expected to be contained in the `data_paths`.
- loader_kwargs: Additional keyword arguments for the dataloader.
Returns:
The PyTorch dataloader.
def
supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_label_paths: Optional[Tuple[str]] = None, val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, checkpoint_path: Optional[str] = None, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
            If both are passed, `ignore_label` takes precedence.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    if checkpoint_path:
        # Continue from an existing checkpoint; building a fresh network first would be wasted work.
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    else:
        is_2d, _ = _determine_ndim(patch_shape)
        build_model = get_2d_model if is_2d else get_3d_model
        model = build_model(out_channels=out_channels, in_channels=in_channels)

    # Without ignore label or mask channel, loss and metric stay None and the trainer defaults are used.
    loss, metric = None, None
    if ignore_label is not None:
        # If we have an ignore label the loss and metric have to be modified
        # so that the ignore mask is not used in the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
Run supervised segmentation training.
This function trains a UNet for predicting outputs for segmentation.
Expects instance labels and converts them to boundary targets.
This behaviour can be changed by passing custom arguments for label_transform
and/or out_channels.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- train_label_paths: Optional paths containing the label data for training. If not given, the labels are expected to be part of train_paths.
- val_label_paths: Optional paths containing the label data for validation. If not given, the labels are expected to be part of val_paths.
- train_rois: Optional regions of interest for training.
- val_rois: Optional regions of interest for validation.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- out_channels: The number of output channels of the UNet.
- mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with ignore_label.
- checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
- loader_kwargs: Additional keyword arguments for the dataloader.