synapse_net.training.supervised_training

  1import os
  2from glob import glob
  3from typing import Optional, Tuple
  4
  5import torch
  6import torch_em
  7from sklearn.model_selection import train_test_split
  8from torch_em.model import AnisotropicUNet, UNet2d
  9
 10from synapse_net.inference.inference import get_model_path, get_available_models
 11
 12
 13def get_3d_model(
 14    out_channels: int,
 15    in_channels: int = 1,
 16    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
 17    initial_features: int = 32,
 18    final_activation: str = "Sigmoid",
 19) -> torch.nn.Module:
 20    """Get the U-Net model for 3D segmentation tasks.
 21
 22    Args:
 23        out_channels: The number of output channels of the network.
 24        scale_factors: The downscaling factors for each level of the U-Net encoder.
 25        initial_features: The number of features in the first level of the U-Net.
 26            The number of features increases by a factor of two in each level.
 27        final_activation: The activation applied to the last output layer.
 28
 29    Returns:
 30        The U-Net.
 31    """
 32    model = AnisotropicUNet(
 33        scale_factors=scale_factors,
 34        in_channels=in_channels,
 35        out_channels=out_channels,
 36        initial_features=initial_features,
 37        gain=2,
 38        final_activation=final_activation,
 39    )
 40    return model
 41
 42
 43def get_2d_model(
 44    out_channels: int,
 45    in_channels: int = 1,
 46    initial_features: int = 32,
 47    final_activation: str = "Sigmoid",
 48) -> torch.nn.Module:
 49    """Get the U-Net model for 2D segmentation tasks.
 50
 51    Args:
 52        out_channels: The number of output channels of the network.
 53        initial_features: The number of features in the first level of the U-Net.
 54            The number of features increases by a factor of two in each level.
 55        final_activation: The activation applied to the last output layer.
 56
 57    Returns:
 58        The U-Net.
 59    """
 60    model = UNet2d(
 61        in_channels=in_channels,
 62        out_channels=out_channels,
 63        initial_features=initial_features,
 64        gain=2,
 65        depth=4,
 66        final_activation=final_activation,
 67    )
 68    return model
 69
 70
 71def _adjust_patch_shape(data_shape, patch_shape):
 72    # If data is 2D and patch_shape is 3D, drop the extra dimension in patch_shape
 73    if data_shape == 2 and len(patch_shape) == 3:
 74        return patch_shape[1:]  # Remove the leading dimension in patch_shape
 75    return patch_shape  # Return the original patch_shape for 3D data
 76
 77
 78def _determine_ndim(patch_shape):
 79    # Check for 2D or 3D training
 80    try:
 81        z, y, x = patch_shape
 82    except ValueError:
 83        y, x = patch_shape
 84        z = 1
 85    is_2d = z == 1
 86    ndim = 2 if is_2d else 3
 87    return is_2d, ndim
 88
 89
def get_supervised_loader(
    data_paths: Tuple[str],
    raw_key: str,
    label_key: str,
    patch_shape: Tuple[int, int, int],
    batch_size: int,
    n_samples: Optional[int],
    add_boundary_transform: bool = True,
    label_dtype=torch.float32,
    rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    label_paths: Optional[Tuple[str]] = None,
    **loader_kwargs,
) -> torch.utils.data.DataLoader:
    """Get a dataloader for supervised segmentation training.

    Args:
        data_paths: The filepaths to the hdf5 files containing the training data.
        raw_key: The key that holds the raw data inside of the hdf5.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        batch_size: The batch size for training.
        n_samples: The number of samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        add_boundary_transform: Whether to add a boundary channel to the training data.
        label_dtype: The datatype of the labels returned by the dataloader.
        rois: Optional region of interests for training.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        label_paths: Optional paths containing the labels / annotations for training.
            If not given, the labels are expected to be contained in the `data_paths`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Returns:
        The PyTorch dataloader.
    """
    _, ndim = _determine_ndim(patch_shape)

    # Select the label transform: an explicitly passed transform takes precedence;
    # otherwise boundary targets (optionally ignore-label aware) or connected components are used.
    if label_transform is not None:  # A specific label transform was passed, do nothing.
        pass
    elif add_boundary_transform:
        if ignore_label is None:
            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
        else:
            # The ignore label is passed through to the target computation so that
            # the ignored areas can later be masked in the loss.
            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
                add_binary_target=True, ignore_label=ignore_label
            )

    else:
        # An ignore label is not supported for connected-component targets.
        if ignore_label is not None:
            raise NotImplementedError
        label_transform = torch_em.transform.label.connected_components

    # Compose padding with the 2d or 3d augmentations. For 2d training the patch shape
    # used for padding is reduced to its two spatial axes.
    if ndim == 2:
        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
        transform = torch_em.transform.Compose(
            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
        )
    else:
        transform = torch_em.transform.Compose(
            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
        )

    # Pop loader-level arguments so they are not passed on twice via loader_kwargs.
    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
    shuffle = loader_kwargs.pop("shuffle", True)

    # By default sample only blocks that contain at least 4 label instances.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)

    # Without separate label paths the labels are read from the data files themselves.
    if label_paths is None:
        label_paths = data_paths
    elif len(label_paths) != len(data_paths):
        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")

    loader = torch_em.default_segmentation_loader(
        data_paths, raw_key,
        label_paths, label_key, sampler=sampler,
        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
        is_seg_dataset=True, label_transform=label_transform, transform=transform,
        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
        label_dtype=label_dtype, rois=rois, **loader_kwargs,
    )
    return loader
180
181
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    # Only visualize loader samples instead of training if 'check' was set.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    # Either resume from a checkpoint or build a fresh 2d / 3d UNet depending on the patch shape.
    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    loss, metric = None, None
    # No ignore label -> we can use default loss.
    if ignore_label is None and not mask_channel:
        pass
    # If we have an ignore label the loss and metric have to be modified
    # so that the ignore mask is not used in the gradient calculation.
    elif ignore_label is not None:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # The last label channel is interpreted as a mask that is applied to prediction and target.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss
    else:
        raise ValueError

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
313
314
315def _derive_key_from_files(files, key):
316    # Get all file extensions (general wild-cards may pick up files with multiple extensions).
317    extensions = list(set([os.path.splitext(ff)[1] for ff in files]))
318
319    # If we have more than 1 file extension we just use the key that was passed,
320    # as it is unclear how to derive a consistent key.
321    if len(extensions) > 1:
322        return files, key
323
324    ext = extensions[0]
325    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}
326
327    # Derive the key from the extension if the key is None.
328    if key is None and ext in extension_to_key:
329        key = extension_to_key[ext]
330    # If the key is None and can't be derived raise an error.
331    elif key is None and ext not in extension_to_key:
332        raise ValueError(
333            f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
334        )
335    # If the key was passed and doesn't match the extension raise an error.
336    elif key is not None and ext in extension_to_key and key != extension_to_key[ext]:
337        raise ValueError(
338            f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
339        )
340    return files, key
341
342
def _parse_input_folder(folder, pattern, key):
    # Recursively collect all files in the folder that match the pattern
    # and derive the internal dataset key from their extension.
    search_expression = os.path.join(folder, "**", pattern)
    matching_files = sorted(glob(search_expression, recursive=True))
    return _derive_key_from_files(matching_files, key)
346
347
def _parse_input_files(args):
    # Resolve the image and label files (and their internal dataset keys) for training.
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match."
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        # No dedicated validation data was passed: split off a validation set from the
        # training data, with a fixed random state for reproducibility.
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        # Dedicated validation folders were passed; re-use the keys derived for the training data.
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key
370
371
372def _parse_checkpoint(initial_model):
373    if initial_model is None:
374        return None
375    if os.path.exists(initial_model):
376        return initial_model
377    model_path = get_model_path(initial_model)
378    return model_path
379
380
def main():
    """@private
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument)."  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # The patch shape is required: without it the patch-based data loaders cannot be constructed,
    # so missing input should fail at argument parsing rather than deep inside training.
    parser.add_argument("-p", "--patch_shape", nargs=3, type=int, required=True,
                        help="The patch shape for training.")

    # Folders with training data, containing raw/image data and labels.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # Optional: choose a model for initializing the weights.
    available_models = get_available_models()
    parser.add_argument(
        "--initial_model",
        help="Choose a model checkpoint for weight initialization.\n"
        "This may either be the path to an existing model checkpoint or the name of a pretrained model.\n"
        f"The following pretrained models are available: {available_models}.\n"
        "If not given, the model will be randomly initialized."
    )

    # More optional argument:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    parser.add_argument("--n_iterations", type=int, default=int(1e5), help="The maximal number of iterations to train for.")  # noqa
    parser.add_argument("--save_root", help="Root path for saving the checkpoint and log dir.")
    args = parser.parse_args()

    # Resolve the train/val file lists and keys, and an optional checkpoint to resume from.
    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)
    checkpoint_path = _parse_checkpoint(args.initial_model)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check, n_iterations=args.n_iterations, save_root=args.save_root,
        checkpoint_path=checkpoint_path
    )
def get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
14def get_3d_model(
15    out_channels: int,
16    in_channels: int = 1,
17    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
18    initial_features: int = 32,
19    final_activation: str = "Sigmoid",
20) -> torch.nn.Module:
21    """Get the U-Net model for 3D segmentation tasks.
22
23    Args:
24        out_channels: The number of output channels of the network.
25        scale_factors: The downscaling factors for each level of the U-Net encoder.
26        initial_features: The number of features in the first level of the U-Net.
27            The number of features increases by a factor of two in each level.
28        final_activation: The activation applied to the last output layer.
29
30    Returns:
31        The U-Net.
32    """
33    model = AnisotropicUNet(
34        scale_factors=scale_factors,
35        in_channels=in_channels,
36        out_channels=out_channels,
37        initial_features=initial_features,
38        gain=2,
39        final_activation=final_activation,
40    )
41    return model

Get the U-Net model for 3D segmentation tasks.

Arguments:
  • out_channels: The number of output channels of the network.
  • scale_factors: The downscaling factors for each level of the U-Net encoder.
  • initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
  • final_activation: The activation applied to the last output layer.
Returns:

The U-Net.

def get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
44def get_2d_model(
45    out_channels: int,
46    in_channels: int = 1,
47    initial_features: int = 32,
48    final_activation: str = "Sigmoid",
49) -> torch.nn.Module:
50    """Get the U-Net model for 2D segmentation tasks.
51
52    Args:
53        out_channels: The number of output channels of the network.
54        initial_features: The number of features in the first level of the U-Net.
55            The number of features increases by a factor of two in each level.
56        final_activation: The activation applied to the last output layer.
57
58    Returns:
59        The U-Net.
60    """
61    model = UNet2d(
62        in_channels=in_channels,
63        out_channels=out_channels,
64        initial_features=initial_features,
65        gain=2,
66        depth=4,
67        final_activation=final_activation,
68    )
69    return model

Get the U-Net model for 2D segmentation tasks.

Arguments:
  • out_channels: The number of output channels of the network.
  • initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
  • final_activation: The activation applied to the last output layer.
Returns:

The U-Net.

def get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, label_paths: Optional[Tuple[str]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
 91def get_supervised_loader(
 92    data_paths: Tuple[str],
 93    raw_key: str,
 94    label_key: str,
 95    patch_shape: Tuple[int, int, int],
 96    batch_size: int,
 97    n_samples: Optional[int],
 98    add_boundary_transform: bool = True,
 99    label_dtype=torch.float32,
100    rois: Optional[Tuple[Tuple[slice]]] = None,
101    sampler: Optional[callable] = None,
102    ignore_label: Optional[int] = None,
103    label_transform: Optional[callable] = None,
104    label_paths: Optional[Tuple[str]] = None,
105    **loader_kwargs,
106) -> torch.utils.data.DataLoader:
107    """Get a dataloader for supervised segmentation training.
108
109    Args:
110        data_paths: The filepaths to the hdf5 files containing the training data.
111        raw_key: The key that holds the raw data inside of the hdf5.
112        label_key: The key that holds the labels inside of the hdf5.
113        patch_shape: The patch shape used for a training example.
114            In order to run 2d training pass a patch shape with a singleton in the z-axis,
115            e.g. 'patch_shape = [1, 512, 512]'.
116        batch_size: The batch size for training.
117        n_samples: The number of samples per epoch. By default this will be estimated
118            based on the patch_shape and size of the volumes used for training.
119        add_boundary_transform: Whether to add a boundary channel to the training data.
120        label_dtype: The datatype of the labels returned by the dataloader.
121        rois: Optional region of interests for training.
122        sampler: Optional sampler for selecting blocks for training.
123            By default a minimum instance sampler will be used.
124        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
125            ignored in the loss computation. By default this option is not used.
126        label_transform: Label transform that is applied to the segmentation to compute the targets.
127            If no label transform is passed (the default) a boundary transform is used.
128        label_paths: Optional paths containing the labels / annotations for training.
129            If not given, the labels are expected to be contained in the `data_paths`.
130        loader_kwargs: Additional keyword arguments for the dataloader.
131
132    Returns:
133        The PyTorch dataloader.
134    """
135    _, ndim = _determine_ndim(patch_shape)
136    if label_transform is not None:  # A specific label transform was passed, do nothing.
137        pass
138    elif add_boundary_transform:
139        if ignore_label is None:
140            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
141        else:
142            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
143                add_binary_target=True, ignore_label=ignore_label
144            )
145
146    else:
147        if ignore_label is not None:
148            raise NotImplementedError
149        label_transform = torch_em.transform.label.connected_components
150
151    if ndim == 2:
152        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
153        transform = torch_em.transform.Compose(
154            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
155        )
156    else:
157        transform = torch_em.transform.Compose(
158            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
159        )
160
161    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
162    shuffle = loader_kwargs.pop("shuffle", True)
163
164    if sampler is None:
165        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
166
167    if label_paths is None:
168        label_paths = data_paths
169    elif len(label_paths) != len(data_paths):
170        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")
171
172    loader = torch_em.default_segmentation_loader(
173        data_paths, raw_key,
174        label_paths, label_key, sampler=sampler,
175        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
176        is_seg_dataset=True, label_transform=label_transform, transform=transform,
177        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
178        label_dtype=label_dtype, rois=rois, **loader_kwargs,
179    )
180    return loader

Get a dataloader for supervised segmentation training.

Arguments:
  • data_paths: The filepaths to the hdf5 files containing the training data.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • batch_size: The batch size for training.
  • n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • add_boundary_transform: Whether to add a boundary channel to the training data.
  • label_dtype: The datatype of the labels returned by the dataloader.
  • rois: Optional region of interests for training.
  • sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
  • ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
  • label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
  • label_paths: Optional paths containing the labels / annotations for training. If not given, the labels are expected to be contained in the data_paths.
  • loader_kwargs: Additional keyword arguments for the dataloader.
Returns:

The PyTorch dataloader.

def supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_label_paths: Optional[Tuple[str]] = None, val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[callable] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[callable] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, checkpoint_path: Optional[str] = None, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    checkpoint_path: Optional[str] = None,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional regions of interest for training.
        val_rois: Optional regions of interest for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are passed.
    """
    # The two loss-masking mechanisms are mutually exclusive (see docstring).
    # Fail fast with a clear message instead of silently preferring one of them.
    if ignore_label is not None and mask_channel:
        raise ValueError("'ignore_label' and 'mask_channel' are not compatible, please pass only one of them.")

    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    # Only visualize a few samples from the loaders instead of running training.
    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if checkpoint_path is not None:
        # Continue training from a previously trained model instead of a fresh one.
        model = torch_em.util.load_model(checkpoint=checkpoint_path)
    elif is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # None means that torch_em's default loss and metric will be used.
    loss, metric = None, None
    if ignore_label is not None:
        # Mask out the areas marked with the ignore label so that they are not
        # used in the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # Use the last label channel as a mask that is applied to the loss.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)

Run supervised segmentation training.

This function trains a UNet for predicting outputs for segmentation. Expects instance labels and converts them to boundary targets. This behaviour can be changed by passing custom arguments for label_transform and/or out_channels.

Arguments:
  • name: The name for the checkpoint to be trained.
  • train_paths: Filepaths to the hdf5 files for the training data.
  • val_paths: Filepaths to the hdf5 files for the validation data.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • train_label_paths: Optional paths containing the label data for training. If not given, the labels are expected to be part of train_paths.
  • val_label_paths: Optional paths containing the label data for validation. If not given, the labels are expected to be part of val_paths.
  • train_rois: Optional regions of interest for training.
  • val_rois: Optional regions of interest for validation.
  • sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • check: Whether to check the training and validation loaders instead of running training.
  • ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
  • label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
  • out_channels: The number of output channels of the UNet.
  • mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with ignore_label.
  • checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
  • loader_kwargs: Additional keyword arguments for the dataloader.