synapse_net.training.supervised_training

import os
from glob import glob
from typing import Callable, Optional, Tuple

import torch
import torch_em
from sklearn.model_selection import train_test_split
from torch_em.model import AnisotropicUNet, UNet2d

from synapse_net.inference.inference import get_model_path, get_available_models
 12
 13def get_3d_model(
 14    out_channels: int,
 15    in_channels: int = 1,
 16    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
 17    initial_features: int = 32,
 18    final_activation: str = "Sigmoid",
 19) -> torch.nn.Module:
 20    """Get the U-Net model for 3D segmentation tasks.
 21
 22    Args:
 23        out_channels: The number of output channels of the network.
 24        scale_factors: The downscaling factors for each level of the U-Net encoder.
 25        initial_features: The number of features in the first level of the U-Net.
 26            The number of features increases by a factor of two in each level.
 27        final_activation: The activation applied to the last output layer.
 28
 29    Returns:
 30        The U-Net.
 31    """
 32    model = AnisotropicUNet(
 33        scale_factors=scale_factors,
 34        in_channels=in_channels,
 35        out_channels=out_channels,
 36        initial_features=initial_features,
 37        gain=2,
 38        final_activation=final_activation,
 39    )
 40    return model
 41
 42
def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # Fixed depth of 4 encoder/decoder levels with a feature gain of 2 per level.
    model = UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return model
 69
 70
 71def _adjust_patch_shape(data_shape, patch_shape):
 72    # If data is 2D and patch_shape is 3D, drop the extra dimension in patch_shape
 73    if data_shape == 2 and len(patch_shape) == 3:
 74        return patch_shape[1:]  # Remove the leading dimension in patch_shape
 75    return patch_shape  # Return the original patch_shape for 3D data
 76
 77
 78def _determine_ndim(patch_shape):
 79    # Check for 2D or 3D training
 80    try:
 81        z, y, x = patch_shape
 82    except ValueError:
 83        y, x = patch_shape
 84        z = 1
 85    is_2d = z == 1
 86    ndim = 2 if is_2d else 3
 87    return is_2d, ndim
 88
 89
def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[Callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Returns:
        The PyTorch dataloader.
    """
    _, ndim = _determine_ndim(patch_shape)
    # Select the label transform: an explicitly passed transform wins, otherwise
    # a boundary transform (default) or connected components is used.
    if label_transform is not None:  # A specific label transform was passed, do nothing.
        pass
    elif add_boundary_transform:
        if ignore_label is None:
            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
        else:
            # Variant of the boundary transform that propagates the ignore mask.
            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                add_binary_target=True, ignore_label=ignore_label
            )

    else:
        # Combining an ignore label with plain connected-components labels is not supported.
        if ignore_label is not None:
            raise NotImplementedError
        label_transform = torch_em.transform.label.connected_components

    # Build the raw-data transform: padding to the patch shape plus augmentations
    # matching the dimensionality of the training.
    if ndim == 2:
        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
        transform = torch_em.transform.Compose(
            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
        )
    else:
        transform = torch_em.transform.Compose(
            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
        )

    # Pop loader-level kwargs so they are not passed twice to the loader factory.
    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    # Labels default to living in the same files as the raw data.
    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader
180
181
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[Callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[Callable] = None,
    loss_fn: Optional[torch.nn.Module] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        loss_fn: Custom loss function. If None, will default to `torch_em.loss.DiceLoss`.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    # Only visualize the loaders for a sanity check instead of training.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # Either resume from a checkpoint or build a fresh 2D / 3D UNet.
    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    base_loss = loss_fn if loss_fn is not None else torch_em.loss.DiceLoss()
    # Default: no masking required, use the loss directly for both loss and metric.
    loss = base_loss
    metric = base_loss

    # If we have an ignore label the loss and metric have to be wrapped so that
    # the ignore mask is excluded from the gradient calculation.
    # NOTE: if both ignore_label and mask_channel are set, ignore_label wins,
    # matching the documented incompatibility of the two options.
    if ignore_label is not None:
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # Use the last label channel as a mask for the loss computation.
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
321
322
323def _derive_key_from_files(files, key):
324    # Get all file extensions (general wild-cards may pick up files with multiple extensions).
325    extensions = list(set([os.path.splitext(ff)[1] for ff in files]))
326
327    # If we have more than 1 file extension we just use the key that was passed,
328    # as it is unclear how to derive a consistent key.
329    if len(extensions) > 1:
330        return files, key
331
332    ext = extensions[0]
333    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}
334
335    # Derive the key from the extension if the key is None.
336    if key is None and ext in extension_to_key:
337        key = extension_to_key[ext]
338    # If the key is None and can't be derived raise an error.
339    elif key is None and ext not in extension_to_key:
340        raise ValueError(
341            f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
342        )
343    # If the key was passed and doesn't match the extension raise an error.
344    elif key is not None and ext in extension_to_key and key != extension_to_key[ext]:
345        raise ValueError(
346            f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
347        )
348    return files, key
349
350
def _parse_input_folder(folder, pattern, key):
    # Recursively collect all files below the folder that match the pattern,
    # in sorted order, and derive the internal dataset key from their extension.
    matches = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
    return _derive_key_from_files(matches, key)
354
355
def _parse_input_files(args):
    """Resolve train/val image and label paths plus internal keys from CLI args.

    If no validation folders are given, the training data is split into
    train/val sets with `train_test_split`.
    """
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match. "
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        # No dedicated validation data: split the training data with a fixed
        # random state so the split is reproducible across runs.
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        # Reuse the keys derived from the training data for the validation data.
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key
378
379
380def _parse_checkpoint(initial_model):
381    if initial_model is None:
382        return None
383    if os.path.exists(initial_model):
384        return initial_model
385    model_path = get_model_path(initial_model)
386    return model_path
387
388
def main():
    """@private
    """
    import argparse

    # RawTextHelpFormatter keeps the explicit newlines in the description intact.
    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument)."  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, help="The patch shape for training.")

    # Folders with training data, containing raw/image data and labels.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # Optional: choose a model for initializing the weights.
    available_models = get_available_models()
    parser.add_argument(
        "--initial_model",
        help="Choose a model checkpoint for weight initialization.\n"
        "This may either be the path to an existing model checkpoint or the name of a pretrained model.\n"
        f"The following pretrained models are available: {available_models}.\n"
        "If not given, the model will be randomly initialized."
    )

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--n_iterations", type=int, default=int(1e5), help="The maximal number of iterations to train for.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")
    args = parser.parse_args()

    # Resolve the file lists / keys and an optional initial checkpoint,
    # then run the actual training.
    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)
    checkpoint_path = _parse_checkpoint(args.initial_model)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check, n_iterations=args.n_iterations, save_root=args.save_root,
        checkpoint_path=checkpoint_path
    )
def get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
14def get_3d_model(
15    out_channels: int,
16    in_channels: int = 1,
17    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
18    initial_features: int = 32,
19    final_activation: str = "Sigmoid",
20) -> torch.nn.Module:
21    """Get the U-Net model for 3D segmentation tasks.
22
23    Args:
24        out_channels: The number of output channels of the network.
25        scale_factors: The downscaling factors for each level of the U-Net encoder.
26        initial_features: The number of features in the first level of the U-Net.
27            The number of features increases by a factor of two in each level.
28        final_activation: The activation applied to the last output layer.
29
30    Returns:
31        The U-Net.
32    """
33    model = AnisotropicUNet(
34        scale_factors=scale_factors,
35        in_channels=in_channels,
36        out_channels=out_channels,
37        initial_features=initial_features,
38        gain=2,
39        final_activation=final_activation,
40    )
41    return model

Get the U-Net model for 3D segmentation tasks.

Arguments:
  • out_channels: The number of output channels of the network.
  • scale_factors: The downscaling factors for each level of the U-Net encoder.
  • initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
  • final_activation: The activation applied to the last output layer.
Returns:

The U-Net.

def get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
44def get_2d_model(
45    out_channels: int,
46    in_channels: int = 1,
47    initial_features: int = 32,
48    final_activation: str = "Sigmoid",
49) -> torch.nn.Module:
50    """Get the U-Net model for 2D segmentation tasks.
51
52    Args:
53        out_channels: The number of output channels of the network.
54        initial_features: The number of features in the first level of the U-Net.
55            The number of features increases by a factor of two in each level.
56        final_activation: The activation applied to the last output layer.
57
58    Returns:
59        The U-Net.
60    """
61    model = UNet2d(
62        in_channels=in_channels,
63        out_channels=out_channels,
64        initial_features=initial_features,
65        gain=2,
66        depth=4,
67        final_activation=final_activation,
68    )
69    return model

Get the U-Net model for 2D segmentation tasks.

Arguments:
  • out_channels: The number of output channels of the network.
  • initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
  • final_activation: The activation applied to the last output layer.
Returns:

The U-Net.

def get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, label_paths: Optional[Tuple[str]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
 91def get_supervised_loader(
 92    data_paths: Tuple[str],
 93    raw_key: str,
 94    label_key: str,
 95    patch_shape: Tuple[int, int, int],
 96    batch_size: int,
 97    n_samples: Optional[int],
 98    add_boundary_transform: bool = True,
 99    label_dtype=torch.float32,
100    rois: Optional[Tuple[Tuple[slice]]] = None,
101    sampler: Optional[callable] = None,
102    ignore_label: Optional[int] = None,
103    label_transform: Optional[callable] = None,
104    label_paths: Optional[Tuple[str]] = None,
105    **loader_kwargs,
106) -> torch.utils.data.DataLoader:
107    """Get a dataloader for supervised segmentation training.
108
109    Args:
110        data_paths: The filepaths to the hdf5 files containing the training data.
111        raw_key: The key that holds the raw data inside of the hdf5.
112        label_key: The key that holds the labels inside of the hdf5.
113        patch_shape: The patch shape used for a training example.
114            In order to run 2d training pass a patch shape with a singleton in the z-axis,
115            e.g. 'patch_shape = [1, 512, 512]'.
116        batch_size: The batch size for training.
117        n_samples: The number of samples per epoch. By default this will be estimated
118            based on the patch_shape and size of the volumes used for training.
119        add_boundary_transform: Whether to add a boundary channel to the training data.
120        label_dtype: The datatype of the labels returned by the dataloader.
121        rois: Optional region of interests for training.
122        sampler: Optional sampler for selecting blocks for training.
123            By default a minimum instance sampler will be used.
124        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
125            ignored in the loss computation. By default this option is not used.
126        label_transform: Label transform that is applied to the segmentation to compute the targets.
127            If no label transform is passed (the default) a boundary transform is used.
128        label_paths: Optional paths containing the labels / annotations for training.
129            If not given, the labels are expected to be contained in the `data_paths`.
130        loader_kwargs: Additional keyword arguments for the dataloader.
131
132    Returns:
133        The PyTorch dataloader.
134    """
135    _, ndim = _determine_ndim(patch_shape)
136    if label_transform is not None:  # A specific label transform was passed, do nothing.
137        pass
138    elif add_boundary_transform:
139        if ignore_label is None:
140            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
141        else:
142            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
143                add_binary_target=True, ignore_label=ignore_label
144            )
145
146    else:
147        if ignore_label is not None:
148            raise NotImplementedError
149        label_transform = torch_em.transform.label.connected_components
150
151    if ndim == 2:
152        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
153        transform = torch_em.transform.Compose(
154            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
155        )
156    else:
157        transform = torch_em.transform.Compose(
158            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
159        )
160
161    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
162    shuffle = loader_kwargs.pop("shuffle", True)
163
164    if sampler is None:
165        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
166
167    if label_paths is None:
168        label_paths = data_paths
169    elif len(label_paths) != len(data_paths):
170        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")
171
172    loader = torch_em.default_segmentation_loader(
173        data_paths, raw_key,
174        label_paths, label_key, sampler=sampler,
175        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
176        is_seg_dataset=True, label_transform=label_transform, transform=transform,
177        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
178        label_dtype=label_dtype, rois=rois, **loader_kwargs,
179    )
180    return loader

Get a dataloader for supervised segmentation training.

Arguments:
  • data_paths: The filepaths to the hdf5 files containing the training data.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • batch_size: The batch size for training.
  • n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • add_boundary_transform: Whether to add a boundary channel to the training data.
  • label_dtype: The datatype of the labels returned by the dataloader.
  • rois: Optional region of interests for training.
  • sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
  • ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
  • label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
  • label_paths: Optional paths containing the labels / annotations for training. If not given, the labels are expected to be contained in the data_paths.
  • loader_kwargs: Additional keyword arguments for the dataloader.
Returns:

The PyTorch dataloader.

def supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_label_paths: Optional[Tuple[str]] = None, val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[Callable] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[Callable] = None, loss_fn: Optional[torch.nn.Module] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, checkpoint_path: Optional[str] = None, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    loss_fn: Optional[torch.nn.Module] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        loss_fn: Custom loss function. If None, will default to `torch_em.loss.DiceLoss`.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        # Only visualize / sanity-check the loaders, don't run training.
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # Either resume from an existing checkpoint or build a fresh 2d / 3d UNet,
    # depending on whether the patch shape has a singleton z-axis.
    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    base_loss = loss_fn if loss_fn is not None else torch_em.loss.DiceLoss()

    # Select the loss / metric. The three cases below are exhaustive; note that if both
    # `ignore_label` and `mask_channel` are given (documented as incompatible),
    # `ignore_label` takes precedence.
    if ignore_label is not None:
        # Wrap the loss so that the areas marked with the ignore label
        # do not contribute to the gradient computation.
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # Use the last label channel as a loss mask.
        loss = torch_em.loss.LossWrapper(
            loss=base_loss,
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss
    else:
        # No masking required -> use the plain loss for both loss and metric.
        loss = base_loss
        metric = base_loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)

Run supervised segmentation training.

This function trains a UNet for predicting outputs for segmentation. Expects instance labels and converts them to boundary targets. This behaviour can be changed by passing custom arguments for label_transform and/or out_channels.

Arguments:
  • name: The name for the checkpoint to be trained.
  • train_paths: Filepaths to the hdf5 files for the training data.
  • val_paths: Filepaths to the hdf5 files for the validation data.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • train_label_paths: Optional paths containing the label data for training. If not given, the labels are expected to be part of train_paths.
  • val_label_paths: Optional paths containing the label data for validation. If not given, the labels are expected to be part of val_paths.
  • train_rois: Optional region of interests for training.
  • val_rois: Optional region of interests for validation.
  • sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • check: Whether to check the training and validation loaders instead of running training.
  • ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
  • label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
  • loss_fn: Custom loss function. If None, will default to torch_em.loss.DiceLoss.
  • out_channels: The number of output channels of the UNet.
  • mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with ignore_label.
  • checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
  • loader_kwargs: Additional keyword arguments for the dataloader.