synapse_net.training.supervised_training

  1from typing import Optional, Tuple
  2
  3import torch
  4import torch_em
  5from torch_em.model import AnisotropicUNet, UNet2d
  6
  7
  8def get_3d_model(
  9    out_channels: int,
 10    in_channels: int = 1,
 11    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
 12    initial_features: int = 32,
 13    final_activation: str = "Sigmoid",
 14) -> torch.nn.Module:
 15    """Get the U-Net model for 3D segmentation tasks.
 16
 17    Args:
 18        out_channels: The number of output channels of the network.
 19        scale_factors: The downscaling factors for each level of the U-Net encoder.
 20        initial_features: The number of features in the first level of the U-Net.
 21            The number of features increases by a factor of two in each level.
 22        final_activation: The activation applied to the last output layer.
 23
 24    Returns:
 25        The U-Net.
 26    """
 27    model = AnisotropicUNet(
 28        scale_factors=scale_factors,
 29        in_channels=in_channels,
 30        out_channels=out_channels,
 31        initial_features=initial_features,
 32        gain=2,
 33        final_activation=final_activation,
 34    )
 35    return model
 36
 37
 38def get_2d_model(
 39    out_channels: int,
 40    in_channels: int = 1,
 41    initial_features: int = 32,
 42    final_activation: str = "Sigmoid",
 43) -> torch.nn.Module:
 44    """Get the U-Net model for 2D segmentation tasks.
 45
 46    Args:
 47        out_channels: The number of output channels of the network.
 48        initial_features: The number of features in the first level of the U-Net.
 49            The number of features increases by a factor of two in each level.
 50        final_activation: The activation applied to the last output layer.
 51
 52    Returns:
 53        The U-Net.
 54    """
 55    model = UNet2d(
 56        in_channels=in_channels,
 57        out_channels=out_channels,
 58        initial_features=initial_features,
 59        gain=2,
 60        depth=4,
 61        final_activation=final_activation,
 62    )
 63    return model
 64
 65
 66def _adjust_patch_shape(data_shape, patch_shape):
 67    # If data is 2D and patch_shape is 3D, drop the extra dimension in patch_shape
 68    if data_shape == 2 and len(patch_shape) == 3:
 69        return patch_shape[1:]  # Remove the leading dimension in patch_shape
 70    return patch_shape  # Return the original patch_shape for 3D data
 71
 72
 73def _determine_ndim(patch_shape):
 74    # Check for 2D or 3D training
 75    try:
 76        z, y, x = patch_shape
 77    except ValueError:
 78        y, x = patch_shape
 79        z = 1
 80    is_2d = z == 1
 81    ndim = 2 if is_2d else 3
 82    return is_2d, ndim
 83
 84
 85def get_supervised_loader(
 86    data_paths: Tuple[str],
 87    raw_key: str,
 88    label_key: str,
 89    patch_shape: Tuple[int, int, int],
 90    batch_size: int,
 91    n_samples: Optional[int],
 92    add_boundary_transform: bool = True,
 93    label_dtype=torch.float32,
 94    rois: Optional[Tuple[Tuple[slice]]] = None,
 95    sampler: Optional[callable] = None,
 96    ignore_label: Optional[int] = None,
 97    label_transform: Optional[callable] = None,
 98    **loader_kwargs,
 99) -> torch.utils.data.DataLoader:
100    """Get a dataloader for supervised segmentation training.
101
102    Args:
103        data_paths: The filepaths to the hdf5 files containing the training data.
104        raw_key: The key that holds the raw data inside of the hdf5.
105        label_key: The key that holds the labels inside of the hdf5.
106        patch_shape: The patch shape used for a training example.
107            In order to run 2d training pass a patch shape with a singleton in the z-axis,
108            e.g. 'patch_shape = [1, 512, 512]'.
109        batch_size: The batch size for training.
110        n_samples: The number of samples per epoch. By default this will be estimated
111            based on the patch_shape and size of the volumes used for training.
112        add_boundary_transform: Whether to add a boundary channel to the training data.
113        label_dtype: The datatype of the labels returned by the dataloader.
114        rois: Optional region of interests for training.
115        sampler: Optional sampler for selecting blocks for training.
116            By default a minimum instance sampler will be used.
117        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
118            ignored in the loss computation. By default this option is not used.
119        label_transform: Label transform that is applied to the segmentation to compute the targets.
120            If no label transform is passed (the default) a boundary transform is used.
121        loader_kwargs: Additional keyword arguments for the dataloader.
122
123    Returns:
124        The PyTorch dataloader.
125    """
126    _, ndim = _determine_ndim(patch_shape)
127    if label_transform is not None:  # A specific label transform was passed, do nothing.
128        pass
129    elif add_boundary_transform:
130        if ignore_label is None:
131            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
132        else:
133            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
134                add_binary_target=True, ignore_label=ignore_label
135            )
136
137    else:
138        if ignore_label is not None:
139            raise NotImplementedError
140        label_transform = torch_em.transform.label.connected_components
141
142    if ndim == 2:
143        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
144        transform = torch_em.transform.Compose(
145            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
146        )
147    else:
148        transform = torch_em.transform.Compose(
149            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
150        )
151
152    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
153    shuffle = loader_kwargs.pop("shuffle", True)
154
155    if sampler is None:
156        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
157
158    loader = torch_em.default_segmentation_loader(
159        data_paths, raw_key,
160        data_paths, label_key, sampler=sampler,
161        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
162        is_seg_dataset=True, label_transform=label_transform, transform=transform,
163        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
164        label_dtype=label_dtype, rois=rois, **loader_kwargs,
165    )
166    return loader
167
168
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
            If both are passed, `ignore_label` takes precedence and `mask_channel` has no effect.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       **loader_kwargs)

    if check:
        # Only inspect the loaders; no model is built and no training is run.
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # Without ignore label and mask channel both stay None, which is passed on
    # to the trainer factory unchanged (i.e. its defaults are used).
    loss, metric = None, None
    if ignore_label is not None:
        # If we have an ignore label the loss and metric have to be modified
        # so that the ignore mask is not used in the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)
def get_3d_model( out_channels: int, in_channels: int = 1, scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
 9def get_3d_model(
10    out_channels: int,
11    in_channels: int = 1,
12    scale_factors: Tuple[Tuple[int, int, int]] = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
13    initial_features: int = 32,
14    final_activation: str = "Sigmoid",
15) -> torch.nn.Module:
16    """Get the U-Net model for 3D segmentation tasks.
17
18    Args:
19        out_channels: The number of output channels of the network.
20        scale_factors: The downscaling factors for each level of the U-Net encoder.
21        initial_features: The number of features in the first level of the U-Net.
22            The number of features increases by a factor of two in each level.
23        final_activation: The activation applied to the last output layer.
24
25    Returns:
26        The U-Net.
27    """
28    model = AnisotropicUNet(
29        scale_factors=scale_factors,
30        in_channels=in_channels,
31        out_channels=out_channels,
32        initial_features=initial_features,
33        gain=2,
34        final_activation=final_activation,
35    )
36    return model

Get the U-Net model for 3D segmentation tasks.

Arguments:
  • out_channels: The number of output channels of the network.
  • in_channels: The number of input channels of the network.
  • scale_factors: The downscaling factors for each level of the U-Net encoder.
  • initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
  • final_activation: The activation applied to the last output layer.
Returns:

The U-Net.

def get_2d_model( out_channels: int, in_channels: int = 1, initial_features: int = 32, final_activation: str = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_2d_model(
    out_channels: int,
    in_channels: int = 1,
    initial_features: int = 32,
    final_activation: str = "Sigmoid",
) -> torch.nn.Module:
    """Build the U-Net used for 2D segmentation tasks.

    Args:
        out_channels: The number of output channels of the network.
        in_channels: The number of input channels of the network.
        initial_features: The number of features in the first level of the U-Net.
            The number of features increases by a factor of two in each level.
        final_activation: The activation applied to the last output layer.

    Returns:
        The U-Net.
    """
    # `gain` and `depth` are fixed (2 and 4) and are deliberately not exposed
    # through this function's signature.
    unet_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        initial_features=initial_features,
        gain=2,
        depth=4,
        final_activation=final_activation,
    )
    return UNet2d(**unet_kwargs)

Get the U-Net model for 2D segmentation tasks.

Arguments:
  • out_channels: The number of output channels of the network.
  • in_channels: The number of input channels of the network.
  • initial_features: The number of features in the first level of the U-Net. The number of features increases by a factor of two in each level.
  • final_activation: The activation applied to the last output layer.
Returns:

The U-Net.

def get_supervised_loader( data_paths: Tuple[str], raw_key: str, label_key: str, patch_shape: Tuple[int, int, int], batch_size: int, n_samples: Optional[int], add_boundary_transform: bool = True, label_dtype=torch.float32, rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, **loader_kwargs) -> torch.utils.data.dataloader.DataLoader:
 86def get_supervised_loader(
 87    data_paths: Tuple[str],
 88    raw_key: str,
 89    label_key: str,
 90    patch_shape: Tuple[int, int, int],
 91    batch_size: int,
 92    n_samples: Optional[int],
 93    add_boundary_transform: bool = True,
 94    label_dtype=torch.float32,
 95    rois: Optional[Tuple[Tuple[slice]]] = None,
 96    sampler: Optional[callable] = None,
 97    ignore_label: Optional[int] = None,
 98    label_transform: Optional[callable] = None,
 99    **loader_kwargs,
100) -> torch.utils.data.DataLoader:
101    """Get a dataloader for supervised segmentation training.
102
103    Args:
104        data_paths: The filepaths to the hdf5 files containing the training data.
105        raw_key: The key that holds the raw data inside of the hdf5.
106        label_key: The key that holds the labels inside of the hdf5.
107        patch_shape: The patch shape used for a training example.
108            In order to run 2d training pass a patch shape with a singleton in the z-axis,
109            e.g. 'patch_shape = [1, 512, 512]'.
110        batch_size: The batch size for training.
111        n_samples: The number of samples per epoch. By default this will be estimated
112            based on the patch_shape and size of the volumes used for training.
113        add_boundary_transform: Whether to add a boundary channel to the training data.
114        label_dtype: The datatype of the labels returned by the dataloader.
115        rois: Optional region of interests for training.
116        sampler: Optional sampler for selecting blocks for training.
117            By default a minimum instance sampler will be used.
118        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
119            ignored in the loss computation. By default this option is not used.
120        label_transform: Label transform that is applied to the segmentation to compute the targets.
121            If no label transform is passed (the default) a boundary transform is used.
122        loader_kwargs: Additional keyword arguments for the dataloader.
123
124    Returns:
125        The PyTorch dataloader.
126    """
127    _, ndim = _determine_ndim(patch_shape)
128    if label_transform is not None:  # A specific label transform was passed, do nothing.
129        pass
130    elif add_boundary_transform:
131        if ignore_label is None:
132            label_transform = torch_em.transform.BoundaryTransform(add_binary_target=True)
133        else:
134            label_transform = torch_em.transform.label.BoundaryTransformWithIgnoreLabel(
135                add_binary_target=True, ignore_label=ignore_label
136            )
137
138    else:
139        if ignore_label is not None:
140            raise NotImplementedError
141        label_transform = torch_em.transform.label.connected_components
142
143    if ndim == 2:
144        adjusted_patch_shape = _adjust_patch_shape(ndim, patch_shape)
145        transform = torch_em.transform.Compose(
146            torch_em.transform.PadIfNecessary(adjusted_patch_shape), torch_em.transform.get_augmentations(2)
147        )
148    else:
149        transform = torch_em.transform.Compose(
150            torch_em.transform.PadIfNecessary(patch_shape), torch_em.transform.get_augmentations(3)
151        )
152
153    num_workers = loader_kwargs.pop("num_workers", 4 * batch_size)
154    shuffle = loader_kwargs.pop("shuffle", True)
155
156    if sampler is None:
157        sampler = torch_em.data.sampler.MinInstanceSampler(min_num_instances=4)
158
159    loader = torch_em.default_segmentation_loader(
160        data_paths, raw_key,
161        data_paths, label_key, sampler=sampler,
162        batch_size=batch_size, patch_shape=patch_shape, ndim=ndim,
163        is_seg_dataset=True, label_transform=label_transform, transform=transform,
164        num_workers=num_workers, shuffle=shuffle, n_samples=n_samples,
165        label_dtype=label_dtype, rois=rois, **loader_kwargs,
166    )
167    return loader

Get a dataloader for supervised segmentation training.

Arguments:
  • data_paths: The filepaths to the hdf5 files containing the training data.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • batch_size: The batch size for training.
  • n_samples: The number of samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • add_boundary_transform: Whether to add a boundary channel to the training data.
  • label_dtype: The datatype of the labels returned by the dataloader.
  • rois: Optional region of interests for training.
  • sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
  • ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
  • label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
  • loader_kwargs: Additional keyword arguments for the dataloader.
Returns:

The PyTorch dataloader.

def supervised_training( name: str, train_paths: Tuple[str], val_paths: Tuple[str], label_key: str, patch_shape: Tuple[int, int, int], save_root: Optional[str] = None, raw_key: str = 'raw', batch_size: int = 1, lr: float = 0.0001, n_iterations: int = 100000, train_rois: Optional[Tuple[Tuple[slice]]] = None, val_rois: Optional[Tuple[Tuple[slice]]] = None, sampler: Optional[<built-in function callable>] = None, n_samples_train: Optional[int] = None, n_samples_val: Optional[int] = None, check: bool = False, ignore_label: Optional[int] = None, label_transform: Optional[<built-in function callable>] = None, in_channels: int = 1, out_channels: int = 2, mask_channel: bool = False, **loader_kwargs):
def supervised_training(
    name: str,
    train_paths: Tuple[str],
    val_paths: Tuple[str],
    label_key: str,
    patch_shape: Tuple[int, int, int],
    save_root: Optional[str] = None,
    raw_key: str = "raw",
    batch_size: int = 1,
    lr: float = 1e-4,
    n_iterations: int = int(1e5),
    train_rois: Optional[Tuple[Tuple[slice]]] = None,
    val_rois: Optional[Tuple[Tuple[slice]]] = None,
    sampler: Optional[callable] = None,
    n_samples_train: Optional[int] = None,
    n_samples_val: Optional[int] = None,
    check: bool = False,
    ignore_label: Optional[int] = None,
    label_transform: Optional[callable] = None,
    in_channels: int = 1,
    out_channels: int = 2,
    mask_channel: bool = False,
    **loader_kwargs,
):
    """Run supervised segmentation training.

    This function trains a UNet for predicting outputs for segmentation.
    Expects instance labels and converts them to boundary targets.
    This behaviour can be changed by passing custom arguments for `label_transform`
    and/or `out_channels`.

    Args:
        name: The name for the checkpoint to be trained.
        train_paths: Filepaths to the hdf5 files for the training data.
        val_paths: Filepaths to the hdf5 files for the validation data.
        label_key: The key that holds the labels inside of the hdf5.
        patch_shape: The patch shape used for a training example.
            In order to run 2d training pass a patch shape with a singleton in the z-axis,
            e.g. 'patch_shape = [1, 512, 512]'.
        save_root: Folder where the checkpoint will be saved.
        raw_key: The key that holds the raw data inside of the hdf5.
        batch_size: The batch size for training.
        lr: The initial learning rate.
        n_iterations: The number of iterations to train for.
        train_rois: Optional region of interests for training.
        val_rois: Optional region of interests for validation.
        sampler: Optional sampler for selecting blocks for training.
            By default a minimum instance sampler will be used.
        n_samples_train: The number of train samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for training.
        n_samples_val: The number of val samples per epoch. By default this will be estimated
            based on the patch_shape and size of the volumes used for validation.
        check: Whether to check the training and validation loaders instead of running training.
        ignore_label: Ignore label in the ground-truth. The areas marked by this label will be
            ignored in the loss computation. By default this option is not used.
        label_transform: Label transform that is applied to the segmentation to compute the targets.
            If no label transform is passed (the default) a boundary transform is used.
        in_channels: The number of input channels of the UNet.
        out_channels: The number of output channels of the UNet.
        mask_channel: Whether the last channels in the labels should be used for masking the loss.
            This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
            If both are passed, `ignore_label` takes precedence and `mask_channel` has no effect.
        loader_kwargs: Additional keyword arguments for the dataloader.
    """
    train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
                                         n_samples=n_samples_train, rois=train_rois, sampler=sampler,
                                         ignore_label=ignore_label, label_transform=label_transform,
                                         **loader_kwargs)
    val_loader = get_supervised_loader(val_paths, raw_key, label_key, patch_shape, batch_size,
                                       n_samples=n_samples_val, rois=val_rois, sampler=sampler,
                                       ignore_label=ignore_label, label_transform=label_transform,
                                       **loader_kwargs)

    if check:
        # Only inspect the loaders; no model is built and no training is run.
        from torch_em.util.debug import check_loader
        check_loader(train_loader, n_samples=4)
        check_loader(val_loader, n_samples=4)
        return

    is_2d, _ = _determine_ndim(patch_shape)
    if is_2d:
        model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
    else:
        model = get_3d_model(out_channels=out_channels, in_channels=in_channels)

    # Without ignore label and mask channel both stay None, which is passed on
    # to the trainer factory unchanged (i.e. its defaults are used).
    loss, metric = None, None
    if ignore_label is not None:
        # If we have an ignore label the loss and metric have to be modified
        # so that the ignore mask is not used in the gradient calculation.
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.MaskIgnoreLabel(
                ignore_label=ignore_label, masking_method="multiply",
            )
        )
        metric = loss
    elif mask_channel:
        loss = torch_em.loss.LossWrapper(
            loss=torch_em.loss.DiceLoss(),
            transform=torch_em.loss.wrapper.ApplyAndRemoveMask(
                masking_method="crop" if out_channels == 1 else "multiply")
        )
        metric = loss

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        mixed_precision=True,
        log_image_interval=100,
        compile_model=False,
        save_root=save_root,
        loss=loss,
        metric=metric,
    )
    trainer.fit(n_iterations)

Run supervised segmentation training.

This function trains a UNet for predicting outputs for segmentation. Expects instance labels and converts them to boundary targets. This behaviour can be changed by passing custom arguments for label_transform and/or out_channels.

Arguments:
  • name: The name for the checkpoint to be trained.
  • train_paths: Filepaths to the hdf5 files for the training data.
  • val_paths: Filepaths to the hdf5 files for the validation data.
  • label_key: The key that holds the labels inside of the hdf5.
  • patch_shape: The patch shape used for a training example. In order to run 2d training pass a patch shape with a singleton in the z-axis, e.g. 'patch_shape = [1, 512, 512]'.
  • save_root: Folder where the checkpoint will be saved.
  • raw_key: The key that holds the raw data inside of the hdf5.
  • batch_size: The batch size for training.
  • lr: The initial learning rate.
  • n_iterations: The number of iterations to train for.
  • train_rois: Optional region of interests for training.
  • val_rois: Optional region of interests for validation.
  • sampler: Optional sampler for selecting blocks for training. By default a minimum instance sampler will be used.
  • n_samples_train: The number of train samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for training.
  • n_samples_val: The number of val samples per epoch. By default this will be estimated based on the patch_shape and size of the volumes used for validation.
  • check: Whether to check the training and validation loaders instead of running training.
  • ignore_label: Ignore label in the ground-truth. The areas marked by this label will be ignored in the loss computation. By default this option is not used.
  • label_transform: Label transform that is applied to the segmentation to compute the targets. If no label transform is passed (the default) a boundary transform is used.
  • in_channels: The number of input channels of the UNet.
  • out_channels: The number of output channels of the UNet.
  • mask_channel: Whether the last channels in the labels should be used for masking the loss. This can be used to implement more complex masking operations and is not compatible with ignore_label.
  • loader_kwargs: Additional keyword arguments for the dataloader.