synapse_net.training.supervised_training
"""Supervised training of U-Net segmentation models with torch_em."""
import copy
from typing import Callable, Optional, Tuple

import torch
import torch_em
from torch_em.model import AnisotropicUNet, UNet2d


def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    model = AnisotropicUNet(
        # The default value is a list shared between all calls. Hand a copy to the model
        # so that nothing downstream can modify the default in place.
        scale_factors=copy.deepcopy(scale_factors),
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model


def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    model = UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return model


def _adjust_patch_shape(data_shape, patch_shape):
    """Return the patch shape matching the data dimensionality.

    Note: despite its name, `data_shape` is compared against the integer 2, i.e. callers
    pass the number of spatial dimensions (2 or 3) rather than a shape.
    """
    # If data is 2D and patch_shape is 3D, drop the extra dimension in patch_shape.
    if data_shape == 2 and len(patch_shape) == 3:
        return patch_shape[1:]  # Remove the leading dimension in patch_shape.
    return patch_shape  # Return the original patch_shape otherwise.


def _determine_ndim(patch_shape):
    """Derive `(is_2d, ndim)` from a patch shape with two or three entries.

    A two-element patch shape, or a three-element one whose first entry is 1,
    is treated as 2D; every other three-element patch shape is 3D.
    """
    try:
        z, y, x = patch_shape
    except ValueError:
        # Not three entries: expect a plain (y, x) patch shape.
        y, x = patch_shape
        z = 1
    is_2d = z == 1
    ndim = 2 if is_2d else 3
    return is_2d, ndim


def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int] = None,
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[Callable] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
            Only used if no `label_transform` is passed.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        loader_kwargs: Additional keyword arguments for the dataloader.
            `num_workers` defaults to `4 * batch_size` and `shuffle` defaults to True.

    Returns:
        The PyTorch dataloader.

    Raises:
        NotImplementedError: If `ignore_label` is given while `add_boundary_transform` is False
            and no `label_transform` is passed.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Only choose a label transform if the caller did not pass one.
    if label_transform is None:
        if add_boundary_transform:
            if ignore_label is None:
                label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
            else:
                label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                    add_binary_target=True, ignore_label=ignore_label
                )
        elif ignore_label is not None:
            raise NotImplementedError(
                "An ignore_label is only supported together with the boundary transform; "
                "pass add_boundary_transform=True or a custom label_transform."
            )
        else:
            label_transform = torch_em.transform.label.connected_components

    # For 2D training a leading singleton is dropped from the padding shape;
    # for 3D training the patch shape is used unchanged.
    pad_shape = _adjust_patch_shape(ndim, patch_shape)
    transform = torch_em.transform.Compose(
        torch_em.transform.PadIfNecessary(pad_shape), torch_em.transform.get_augmentations(ndim)
    )

    # Pop these so that they are not passed twice via **loader_kwargs.
    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    # Raw data and labels are read from the same files, under different keys.
    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        data_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader


def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[Callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
            The same sampler is used for the training and the validation loader.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are set when training is started.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       **loader_kwargs)

    if check:
        # Only inspect the loaders; no model is built and no training is run.
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # The two masking mechanisms are documented as incompatible. Without this check
    # the mask channel would be silently ignored whenever an ignore label is given.
    if ignore_label is not None and mask_channel:
        raise ValueError("`ignore_label` and `mask_channel` cannot be combined; choose one masking mechanism.")

    is_2d, _ = _determine_ndim(patch_shape)
    if is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # Without any masking we pass None and rely on the default loss and metric of the trainer.
    loss, metric = None, None
    if ignore_label is not None:
        # If we have an ignore label the loss and metric have to be modified
        # so that the ignore mask is not used in the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
def
get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    import copy

    model = AnisotropicUNet(
        # The default value is a list shared between all calls. Hand a copy to the model
        # so that nothing downstream can modify the default in place.
        scale_factors=copy.deepcopy(scale_factors),
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model
Get the U-Net model for 3D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- scale_factors: The downscaling factors for each level of the U-Net encoder.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Build the U-Net used for 2D segmentation tasks.

    The network is a `UNet2d` with a fixed depth of 4 and a feature gain of 2.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # Collect the configuration first: the caller-provided values plus the fixed settings.
    unet_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return UNet2d(**unet_kwargs)
Get the U-Net model for 2D segmentation tasks.
Arguments:
- out_channels: The number of output channels of the network.
- initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
- final_activation: The activation applied to the last output layer.
Returns:
The U-Net.
def
get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int] = None,
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
            Only used if no `label_transform` is passed.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        loader_kwargs: Additional keyword arguments for the dataloader.
            `num_workers` defaults to `4 * batch_size` and `shuffle` defaults to True.

    Returns:
        The PyTorch dataloader.

    Raises:
        NotImplementedError: If `ignore_label` is given while `add_boundary_transform` is False
            and no `label_transform` is passed.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Only choose a label transform if the caller did not pass one.
    if label_transform is None:
        if add_boundary_transform:
            if ignore_label is None:
                label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
            else:
                label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                    add_binary_target=True, ignore_label=ignore_label
                )
        elif ignore_label is not None:
            raise NotImplementedError(
                "An ignore_label is only supported together with the boundary transform; "
                "pass add_boundary_transform=True or a custom label_transform."
            )
        else:
            label_transform = torch_em.transform.label.connected_components

    # For 2D training a leading singleton is dropped from the padding shape;
    # for 3D training the patch shape is used unchanged.
    pad_shape = _adjust_patch_shape(ndim, patch_shape)
    transform = torch_em.transform.Compose(
        torch_em.transform.PadIfNecessary(pad_shape), torch_em.transform.get_augmentations(ndim)
    )

    # Pop these so that they are not passed twice via **loader_kwargs.
    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    # Raw data and labels are read from the same files, under different keys.
    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        data_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader
Get a dataloader for supervised segmentation training.
Arguments:
- data_paths: The filepaths to the hdf5 files containing the training data.
- raw_key: The key that holds the raw data inside of the hdf5.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- batch_size: The batch size for training.
- n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- add_boundary_transform: Whether to add a boundary channel to the training data.
- label_dtype: The datatype of the labels returned by the dataloader.
- rois: Optional region of interests for training.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- loader_kwargs: Additional keyword arguments for the dataloader.
Returns:
The PyTorch dataloader.
def
supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
            The same sampler is used for the training and the validation loader.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are set when training is started.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       **loader_kwargs)

    if check:
        # Only inspect the loaders; no model is built and no training is run.
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # The two masking mechanisms are documented as incompatible. Without this check
    # the mask channel would be silently ignored whenever an ignore label is given.
    if ignore_label is not None and mask_channel:
        raise ValueError("`ignore_label` and `mask_channel` cannot be combined; choose one masking mechanism.")

    is_2d, _ = _determine_ndim(patch_shape)
    if is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # Without any masking we pass None and rely on the default loss and metric of the trainer.
    loss, metric = None, None
    if ignore_label is not None:
        # If we have an ignore label the loss and metric have to be modified
        # so that the ignore mask is not used in the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
Run supervised segmentation training.
This function trains a UNet for predicting outputs for segmentation.
Expects instance labels and converts them to boundary targets.
This behaviour can be changed by passing custom arguments for `label_transform` and/or `out_channels`.
Arguments:
- name: The name for the checkpoint to be trained.
- train_paths: Filepaths to the hdf5 files for the training data.
- val_paths: Filepaths to the hdf5 files for the validation data.
- label_key: The key that holds the labels inside of the hdf5.
- patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
- save_root: Folder where the checkpoint will be saved.
- raw_key: The key that holds the raw data inside of the hdf5.
- batch_size: The batch size for training.
- lr: The initial learning rate.
- n_iterations: The number of iterations to train for.
- train_rois: Optional region of interests for training.
- val_rois: Optional region of interests for validation.
- sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
- n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
- n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
- check: Whether to check the training and validation loaders instead of running training.
- ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
- label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
- out_channels: The number of output channels of the UNet.
- mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
- loader_kwargs: Additional keyword arguments for the dataloader.