synapse_net.training.supervised_training

  1import os
  2from glob import glob
  3from typing import Optional, Tuple
  4
  5import torch
  6import torch_em
  7from sklearn.model_selection import train_test_split
  8from torch_em.model import AnisotropicUNet, UNet2d
  9
 10
 11def get_3d_model(
 12    out_channels: int,
 13    in_channels: int = 1,
 14    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
 15    initial_features: int = 32,
 16    final_activation: str = "Sigmoid",
 17) -> torch.nn.Module:
 18    """Get the U-Net model for 3D segmentation tasks.
 19
 20    Args:
 21        out_channels: The number of output channels of the network.
 22        scale_factors: The downscaling factors for each level of the U-Net encoder.
 23        initial_features: The number of features in the first level of the U-Net.
 24            The number of features increases by a factor of two in each level.
 25        final_activation: The activation applied to the last output layer.
 26
 27    Returns:
 28        The U-Net.
 29    """
 30    model = AnisotropicUNet(
 31        scale_factors=scale_factors,
 32        in_channels=in_channels,
 33        out_channels=out_channels,
 34        initial_features=initial_features,
 35        gain=2,
 36        final_activation=final_activation,
 37    )
 38    return model
 39
 40
 41def get_2d_model(
 42    out_channels: int,
 43    in_channels: int = 1,
 44    initial_features: int = 32,
 45    final_activation: str = "Sigmoid",
 46) -> torch.nn.Module:
 47    """Get the U-Net model for 2D segmentation tasks.
 48
 49    Args:
 50        out_channels: The number of output channels of the network.
 51        initial_features: The number of features in the first level of the U-Net.
 52            The number of features increases by a factor of two in each level.
 53        final_activation: The activation applied to the last output layer.
 54
 55    Returns:
 56        The U-Net.
 57    """
 58    model = UNet2d(
 59        in_channels=in_channels,
 60        out_channels=out_channels,
 61        initial_features=initial_features,
 62        gain=2,
 63        depth=4,
 64        final_activation=final_activation,
 65    )
 66    return model
 67
 68
 69def _adjust_patch_shape(data_shape, patch_shape):
 70    # If data is 2D and patch_shape is 3D, drop the extra dimension in patch_shape
 71    if data_shape == 2 and len(patch_shape) == 3:
 72        return patch_shape[1:]  # Remove the leading dimension in patch_shape
 73    return patch_shape  # Return the original patch_shape for 3D data
 74
 75
 76def _determine_ndim(patch_shape):
 77    # Check for 2D or 3D training
 78    try:
 79        z, y, x = patch_shape
 80    except ValueError:
 81        y, x = patch_shape
 82        z = 1
 83    is_2d = z == 1
 84    ndim = 2 if is_2d else 3
 85    return is_2d, ndim
 86
 87
 88def get_supervised_loader(
 89    data_paths: Tuple[str],
 90    raw_key: str,
 91    label_key: str,
 92    patch_shape: Tuple[int, int, int],
 93    batch_size: int,
 94    n_samples: Optional[int],
 95    add_boundary_transform: bool = True,
 96    label_dtype=torch.float32,
 97    rois: Optional[Tuple[Tuple[slice]]] = None,
 98    sampler: Optional[callable] = None,
 99    ignore_label: Optional[int] = None,
100    label_transform: Optional[callable] = None,
101    label_paths: Optional[Tuple[str]] = None,
102    **loader_kwargs,
103) -> torch.utils.data.DataLoader:
104    """Get a dataloader for supervised segmentation training.
105
106    Args:
107        data_paths: The filepaths to the hdf5 files containing the training data.
108        raw_key: The key that holds the raw data inside of the hdf5.
109        label_key: The key that holds the labels inside of the hdf5.
110        patch_shape: The patch shape used for a training example.
111            In order to run 2d training pass a patch shape with a singleton in the z-axis,
112            e.g. 'patch_shape = [1, 512, 512]'.
113        batch_size: The batch size for training.
114        n_samples: The number of samples per epoch. By default this will be estimated
115            based on the patch_shape and size of the volumes used for training.
116        add_boundary_transform: Whether to add a boundary channel to the training data.
117        label_dtype: The datatype of the labels returned by the dataloader.
118        rois: Optional region of interests for training.
119        sampler: Optional sampler for selecting blocks for training.
120            By default a minimum instance sampler will be used.
121        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
122            ignored in the loss computation. By default this option is not used.
123        label_transform: Label transform that is applied to the segmentation to compute the targets.
124            If no label transform is passed (the default) a boundary transform is used.
125        label_paths: Optional paths containing the labels / annotations for training.
126            If not given, the labels are expected to be contained in the `data_paths`.
127        loader_kwargs: Additional keyword arguments for the dataloader.
128
129    Returns:
130        The PyTorch dataloader.
131    """
132    _, ndim = _determine_ndim(patch_shape)
133    if label_transform is not None:  # A specific label transform was passed, do nothing.
134        pass
135    elif add_boundary_transform:
136        if ignore_label is None:
137            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
138        else:
139            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
140                add_binary_target=True, ignore_label=ignore_label
141            )
142
143    else:
144        if ignore_label is not None:
145            raise NotImplementedError
146        label_transform = torch_em.transform.label.connected_components
147
148    if ndim == 2:
149        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
150        transform = torch_em.transform.Compose(
151            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
152        )
153    else:
154        transform = torch_em.transform.Compose(
155            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
156        )
157
158    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
159    shuffle = loader_kwargs.pop("shuffle", True)
160
161    if sampler is None:
162        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
163
164    if label_paths is None:
165        label_paths = data_paths
166    elif len(label_paths) != len(data_paths):
167        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")
168
169    loader = torch_em.default_segmentation_loader(
170        data_paths, raw_key,
171        label_paths, label_key, sampler=sampler,
172        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
173        is_seg_dataset=True, label_transform=label_transform, transform=transform,
174        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
175        label_dtype=label_dtype, rois=rois, **loader_kwargs,
176    )
177    return loader
178
179
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional regions of interest for training.
        val_rois: Optional regions of interest for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are set.
    """
    # Fail early, before any data is loaded, if the two mutually exclusive masking options are combined.
    if ignore_label is not None and mask_channel:
        raise ValueError(
            "The options 'ignore_label' and 'mask_channel' are not compatible. Please only use one of them."
        )

    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # Without any masking, loss and metric stay None so that the default loss is used.
    loss, metric = None, None
    if ignore_label is not None:
        # The ignore mask must not contribute to the gradient, so the loss and metric are masked.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        # Mask the loss with the mask channel(s) contained in the labels.
        # NOTE(review): 'crop' is used for single-channel outputs, 'multiply' otherwise; rationale not visible here.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
307
308
309def _derive_key_from_files(files, key):
310    # Get all file extensions (general wild-cards may pick up files with multiple extensions).
311    extensions = list(set([os.path.splitext(ff)[1] for ff in files]))
312
313    # If we have more than 1 file extension we just use the key that was passed,
314    # as it is unclear how to derive a consistent key.
315    if len(extensions) > 1:
316        return files, key
317
318    ext = extensions[0]
319    extension_to_key = {".tif": None, ".mrc": "data", ".rec": "data"}
320
321    # Derive the key from the extension if the key is None.
322    if key is None and ext in extension_to_key:
323        key = extension_to_key[ext]
324    # If the key is None and can't be derived raise an error.
325    elif key is None and ext not in extension_to_key:
326        raise ValueError(
327            f"You have not passed a key for the data in {ext} format, for which the key cannot be derived."
328        )
329    # If the key was passed and doesn't match the extension raise an error.
330    elif key is not None and ext in extension_to_key and key != extension_to_key[ext]:
331        raise ValueError(
332            f"The expected key {extension_to_key[ext]} for format {ext} did not match the passed key {key}."
333        )
334    return files, key
335
336
337def _parse_input_folder(folder, pattern, key):
338    files = sorted(glob(os.path.join(folder, "**", pattern), recursive=True))
339    return _derive_key_from_files(files, key)
340
341
def _parse_input_files(args):
    """Parse the training and validation image / label paths from the command line arguments.

    If no validation folders are given, the training data is split into training and
    validation data according to `args.val_fraction`.

    Args:
        args: The parsed command line arguments.

    Returns:
        The training image paths, training label paths, validation image paths,
        validation label paths, the key for the raw data and the key for the label data.

    Raises:
        ValueError: If the number of image and label files does not match, or if only one
            of `val_folder` and `val_label_folder` was passed.
    """
    train_image_paths, raw_key = _parse_input_folder(args.train_folder, args.image_file_pattern, args.raw_key)
    train_label_paths, label_key = _parse_input_folder(args.label_folder, args.label_file_pattern, args.label_key)
    if len(train_image_paths) != len(train_label_paths):
        raise ValueError(
            f"The image and label paths parsed from {args.train_folder} and {args.label_folder} don't match. "
            f"The image folder contains {len(train_image_paths)}, the label folder contains {len(train_label_paths)}."
        )

    if args.val_folder is None:
        if args.val_label_folder is not None:
            raise ValueError("You have passed a val_label_folder, but not a val_folder.")
        train_image_paths, val_image_paths, train_label_paths, val_label_paths = train_test_split(
            train_image_paths, train_label_paths, test_size=args.val_fraction, random_state=42
        )
    else:
        if args.val_label_folder is None:
            raise ValueError("You have passed a val_folder, but not a val_label_folder.")
        # _parse_input_folder returns (paths, key); the keys derived for training are reused,
        # so only the paths are kept here.
        val_image_paths, _ = _parse_input_folder(args.val_folder, args.image_file_pattern, raw_key)
        val_label_paths, _ = _parse_input_folder(args.val_label_folder, args.label_file_pattern, label_key)
        if len(val_image_paths) != len(val_label_paths):
            raise ValueError(
                f"The image and label paths parsed from {args.val_folder} and {args.val_label_folder} don't match. "
                f"The image folder contains {len(val_image_paths)}, the label folder contains {len(val_label_paths)}."
            )

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key
364
365
366# TODO enable initialization with a pre-trained model.
def main():
    """@private
    """
    # Command line entry point: parses the arguments, collects the input files and runs the training.
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for foreground and boundary segmentation via supervised learning.\n\n"
        "You can use this function to train a model for vesicle segmentation, or another segmentation task, like this:\n"  # noqa
        "synapse_net.run_supervised_training -n my_model -i /path/to/images -l /path/to/labels --patch_shape 32 192 192\n"  # noqa
        "The trained model will be saved in the folder 'checkpoints/my_model' (or whichever name you pass to the '-n' argument). "  # noqa
        "You can then use this model for segmentation with the SynapseNet GUI or CLI. "
        "Check out the information below for details on the arguments of this function.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-n", "--name", required=True, help="The name of the model to be trained.")
    # The patch shape is needed to decide between 2D and 3D training, so it must be given.
    parser.add_argument(
        "-p", "--patch_shape", nargs=3, type=int, required=True, help="The patch shape for training."
    )

    # Folders with training data, containing raw/image data and labels.
    parser.add_argument("-i", "--train_folder", required=True, help="The input folder with the training image data.")
    parser.add_argument("--image_file_pattern", default="*",
                        help="The pattern for selecting image files. For example, '*.mrc' to select all mrc files.")
    parser.add_argument("--raw_key",
                        help="The internal path for the raw data. If not given, will be determined based on the file extension.")  # noqa
    parser.add_argument("-l", "--label_folder", required=True, help="The input folder with the training labels.")
    parser.add_argument("--label_file_pattern", default="*",
                        help="The pattern for selecting label files. For example, '*.tif' to select all tif files.")
    parser.add_argument("--label_key",
                        help="The internal path for the label data. If not given, will be determined based on the file extension.")  # noqa

    # Optional folders with validation data. If not given the training data is split into train/val.
    parser.add_argument("--val_folder",
                        help="The input folder with the validation data. If not given the training data will be split for validation")  # noqa
    parser.add_argument("--val_label_folder",
                        help="The input folder with the validation labels. If not given the training data will be split for validation.")  # noqa

    # More optional arguments:
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for training.")
    parser.add_argument("--n_samples_train", type=int, help="The number of samples per epoch for training. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--n_samples_val", type=int, help="The number of samples per epoch for validation. If not given will be derived from the data size.")  # noqa
    parser.add_argument("--val_fraction", type=float, default=0.15, help="The fraction of the data to use for validation. This has no effect if 'val_folder' and 'val_label_folder' were passed.")  # noqa
    parser.add_argument("--check", action="store_true", help="Visualize samples from the data loaders to ensure correct data instead of running training.")  # noqa
    args = parser.parse_args()

    train_image_paths, train_label_paths, val_image_paths, val_label_paths, raw_key, label_key =\
        _parse_input_files(args)

    supervised_training(
        name=args.name, train_paths=train_image_paths, val_paths=val_image_paths,
        train_label_paths=train_label_paths, val_label_paths=val_label_paths,
        raw_key=raw_key, label_key=label_key, patch_shape=args.patch_shape, batch_size=args.batch_size,
        n_samples_train=args.n_samples_train, n_samples_val=args.n_samples_val,
        check=args.check,
    )
def get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_3d_model(
    out_channels: int,
    in_channels: int = 1,
    scale_factors: Optional[Tuple[Tuple[int, int, int], ...]] = None,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 3D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        scale_factors: The downscaling factors for each level of the U-Net encoder.
            If not given, `[[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]` is used.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # Build the default inside the function, so that no mutable list is shared between calls.
    if scale_factors is None:
        scale_factors = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
    model = AnisotropicUNet(
        scale_factors=scale_factors,
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        final_activation=final_activation,
    )
    return model

Get the U-Net model for 3D segmentation tasks.

Arguments:
  • out_channels: The number of output channels of the network.
  • in_channels: The number of input channels of the network.
  • scale_factors: The downscaling factors for each level of the U-Net encoder.
  • initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
  • final_activation: The activation applied to the last output layer.
Returns:

The U-Net.

def get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Get the U-Net model for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net (constructed with `depth=4`).
    """
    model = UNet2d(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return model

Get the U-Net model for 2D segmentation tasks.

Arguments:
  • out_channels: The number of output channels of the network.
  • in_channels: The number of input channels of the network.
  • initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
  • final_activation: The activation applied to the last output layer.
Returns:

The U-Net.

def get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, label_paths: Optional[Tuple[str]] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
 89def get_supervised_loader(
 90    data_paths: Tuple[str],
 91    raw_key: str,
 92    label_key: str,
 93    patch_shape: Tuple[int, int, int],
 94    batch_size: int,
 95    n_samples: Optional[int],
 96    add_boundary_transform: bool = True,
 97    label_dtype=torch.float32,
 98    rois: Optional[Tuple[Tuple[slice]]] = None,
 99    sampler: Optional[callable] = None,
100    ignore_label: Optional[int] = None,
101    label_transform: Optional[callable] = None,
102    label_paths: Optional[Tuple[str]] = None,
103    **loader_kwargs,
104) -> torch.utils.data.DataLoader:
105    """Get a dataloader for supervised segmentation training.
106
107    Args:
108        data_paths: The filepaths to the hdf5 files containing the training data.
109        raw_key: The key that holds the raw data inside of the hdf5.
110        label_key: The key that holds the labels inside of the hdf5.
111        patch_shape: The patch shape used for a training example.
112            In order to run 2d training pass a patch shape with a singleton in the z-axis,
113            e.g. 'patch_shape = [1, 512, 512]'.
114        batch_size: The batch size for training.
115        n_samples: The number of samples per epoch. By default this will be estimated
116            based on the patch_shape and size of the volumes used for training.
117        add_boundary_transform: Whether to add a boundary channel to the training data.
118        label_dtype: The datatype of the labels returned by the dataloader.
119        rois: Optional region of interests for training.
120        sampler: Optional sampler for selecting blocks for training.
121            By default a minimum instance sampler will be used.
122        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
123            ignored in the loss computation. By default this option is not used.
124        label_transform: Label transform that is applied to the segmentation to compute the targets.
125            If no label transform is passed (the default) a boundary transform is used.
126        label_paths: Optional paths containing the labels / annotations for training.
127            If not given, the labels are expected to be contained in the `data_paths`.
128        loader_kwargs: Additional keyword arguments for the dataloader.
129
130    Returns:
131        The PyTorch dataloader.
132    """
133    _, ndim = _determine_ndim(patch_shape)
134    if label_transform is not None:  # A specific label transform was passed, do nothing.
135        pass
136    elif add_boundary_transform:
137        if ignore_label is None:
138            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
139        else:
140            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
141                add_binary_target=True, ignore_label=ignore_label
142            )
143
144    else:
145        if ignore_label is not None:
146            raise NotImplementedError
147        label_transform = torch_em.transform.label.connected_components
148
149    if ndim == 2:
150        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
151        transform = torch_em.transform.Compose(
152            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
153        )
154    else:
155        transform = torch_em.transform.Compose(
156            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
157        )
158
159    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
160    shuffle = loader_kwargs.pop("shuffle", True)
161
162    if sampler is None:
163        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
164
165    if label_paths is None:
166        label_paths = data_paths
167    elif len(label_paths) != len(data_paths):
168        raise ValueError(f"Data paths and label paths don't match: {len(data_paths)} != {len(label_paths)}")
169
170    loader = torch_em.default_segmentation_loader(
171        data_paths, raw_key,
172        label_paths, label_key, sampler=sampler,
173        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
174        is_seg_dataset=True, label_transform=label_transform, transform=transform,
175        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
176        label_dtype=label_dtype, rois=rois, **loader_kwargs,
177    )
178    return loader

Get a dataloader for supervised segmentation training.

Arguments:
  • data_paths: The filepaths to the hdf5 files containing the training data.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • batch_size: The batch size for training.
  • n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • add_boundary_transform: Whether to add a boundary channel to the training data.
  • label_dtype: The datatype of the labels returned by the dataloader.
  • rois: Optional regions of interest for training.
  • sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
  • ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
  • label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
  • label_paths: Optional paths containing the labels / annotations for training. If not given, the labels are expected to be contained in the data_paths.
  • loader_kwargs: Additional keyword arguments for the dataloader.
Returns:

The PyTorch dataloader.

def supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_label_paths: Optional[Tuple[str]] = None, val_label_paths: Optional[Tuple[str]] = None, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_label_paths: Optional[Tuple[str]] = None,
    val_label_paths: Optional[Tuple[str]] = None,
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet to predict segmentation targets.
    It expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_label_paths: Optional paths containing the label data for training.
            If not given, the labels are expected to be part of `train_paths`.
        val_label_paths: Optional paths containing the label data for validation.
            If not given, the labels are expected to be part of `val_paths`.
        train_rois: Optional regions of interest for training.
        val_rois: Optional regions of interest for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
        loader_kwargs: Additional keyword arguments for the dataloader.

    Raises:
        ValueError: If both `ignore_label` and `mask_channel` are set.
    """
    # The two masking options are mutually exclusive. Reject the combination up front,
    # before any data is loaded, instead of silently ignoring `mask_channel`.
    if ignore_label is not None and mask_channel:
        raise ValueError(
            "The options 'ignore_label' and 'mask_channel' are not compatible; set at most one of them."
        )

    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         label_paths=train_label_paths, **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       label_paths=val_label_paths, **loader_kwargs)

    if check:
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    get_model = get_2d_model if is_2d else get_3d_model
    model = get_model(out_channels=out_channels, in_channels=in_channels)

    # Without masking, loss and metric stay None so the trainer's defaults are used.
    # With masking, the mask transform is applied before the Dice loss so that the
    # masked-out areas are not used in the gradient calculation.
    if ignore_label is not None:
        mask_transform = torch_em.loss.wrapper.MaskIgnoreLabel(
            ignore_label=ignore_label, masking_method="multiply",
        )
    elif mask_channel:
        mask_transform = torch_em.loss.wrapper.ApplyAndRemoveMask(
            masking_method="crop" if out_channels == 1 else "multiply"
        )
    else:
        mask_transform = None

    if mask_transform is None:
        loss, metric = None, None
    else:
        loss = torch_em.loss.LossWrapper(loss=torch_em.loss.DiceLoss(), transform=mask_transform)
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)

Run supervised segmentation training.

This function trains a UNet to predict segmentation targets. It expects instance labels and converts them to boundary targets. This behaviour can be changed by passing custom arguments for label_transform and/or out_channels.

Arguments:
  • name: The name for the checkpoint to be trained.
  • train_paths: Filepaths to the hdf5 files for the training data.
  • val_paths: Filepaths to the hdf5 files for the validation data.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • train_label_paths: Optional paths containing the label data for training. If not given, the labels are expected to be part of train_paths.
  • val_label_paths: Optional paths containing the label data for validation. If not given, the labels are expected to be part of val_paths.
  • train_rois: Optional regions of interest for training.
  • val_rois: Optional regions of interest for validation.
  • sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • check: Whether to check the training and validation loaders instead of running training.
  • ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
  • label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
  • in_channels: The number of input channels of the UNet.
  • out_channels: The number of output channels of the UNet.
  • mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with ignore_label.
  • loader_kwargs: Additional keyword arguments for the dataloader.