
  1import os
  2import time
  3import warnings
  4from glob import glob
  5from tqdm import tqdm
  6from contextlib import contextmanager, nullcontext
  7from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  9import imageio.v3 as imageio
 11import torch
 12from torch.optim import Optimizer
 13from torch.optim.lr_scheduler import _LRScheduler
 14from torch.utils.data import DataLoader, Dataset
 16import torch_em
 17from torch_em.data.datasets.util import split_kwargs
 19from elf.io import open_file
 22    from qtpy.QtCore import QObject
 23except Exception:
 24    QObject = Any
 26from ..util import get_device
 27from . import sam_trainer as trainers
 28from ..instance_segmentation import get_unetr
 29from . import joint_sam_trainer as joint_trainers
 30from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit
 33FilePath = Union[str, os.PathLike]
 36def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
 37    x, _ = next(iter(loader))
 39    # Raw data: check that we have 1 or 3 channels.
 40    n_channels = x.shape[1]
 41    if n_channels not in (1, 3):
 42        raise ValueError(
 43            "Invalid number of channels for the input data from the data loader. "
 44            f"Expect 1 or 3 channels, got {n_channels}."
 45        )
 47    # Raw data: check that it is between [0, 255]
 48    minval, maxval = x.min(), x.max()
 49    if minval < 0 or minval > 255:
 50        raise ValueError(
 51            "Invalid data range for the input data from the data loader. "
 52            f"The input has to be in range [0, 255], but got minimum value {minval}."
 53        )
 54    if maxval < 1 or maxval > 255:
 55        raise ValueError(
 56            "Invalid data range for the input data from the data loader. "
 57            f"The input has to be in range [0, 255], but got maximum value {maxval}."
 58        )
 60    # Target data: the check depends on whether we train with or without decoder.
 61    # NOTE: Verification step to check whether all labels from dataloader are valid (i.e. have atleast one instance).
 63    def _check_instance_channel(instance_channel):
 64        unique_vals = torch.unique(instance_channel)
 65        if (unique_vals < 0).any():
 66            raise ValueError(
 67                "The target channel with the instance segmentation must not have negative values."
 68            )
 69        if len(unique_vals) == 1:
 70            raise ValueError(
 71                "The target channel with the instance segmentation must have at least one instance."
 72            )
 73        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
 74            raise ValueError(
 75                "All values in the target channel with the instance segmentation must be integer."
 76            )
 78    counter = 0
 79    name = "" if name is None else f"'{name}'"
 80    for x, y in tqdm(
 81        loader,
 82        desc=f"Verifying labels in {name} dataloader",
 83        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
 84    ):
 85        n_channels_y = y.shape[1]
 86        if with_segmentation_decoder:
 87            if n_channels_y != 4:
 88                raise ValueError(
 89                    "Invalid number of channels in the target data from the data loader. "
 90                    "Expect 4 channel for training with an instance segmentation decoder, "
 91                    f"but got {n_channels_y} channels."
 92                )
 93            # Check instance channel per sample in a batch
 94            for per_y_sample in y:
 95                _check_instance_channel(per_y_sample[0])
 97            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
 98            if targets_min < 0 or targets_min > 1:
 99                raise ValueError(
100                    "Invalid value range in the target data from the value loader. "
101                    "Expect the 3 last target channels (for normalized distances and foreground probabilities)"
102                    f"to be in range [0.0, 1.0], but got min {targets_min}"
103                )
104            if targets_max < 0 or targets_max > 1:
105                raise ValueError(
106                    "Invalid value range in the target data from the value loader. "
107                    "Expect the 3 last target channels (for normalized distances and foreground probabilities)"
108                    f"to be in range [0.0, 1.0], but got max {targets_max}"
109                )
111        else:
112            if n_channels_y != 1:
113                raise ValueError(
114                    "Invalid number of channels in the target data from the data loader. "
115                    "Expect 1 channel for training without an instance segmentation decoder,"
116                    f"but got {n_channels_y} channels."
117                )
118            # Check instance channel per sample in a batch
119            for per_y_sample in y:
120                _check_instance_channel(per_y_sample)
122        counter += 1
123        if verify_n_labels_in_loader is not None and counter > verify_n_labels_in_loader:
124            break
class _ProgressBarWrapper:
    """Adapter that forwards tqdm-style progress bar calls to a signals object.

    The wrapped object must provide `pbar_total`, `pbar_update` and `pbar_description`
    attributes, each with an `emit` method.
    """

    def __init__(self, signals):
        self._total = None
        self._signals = signals

    @property
    def total(self):
        """The last value assigned to `total` (None until it was set)."""
        return self._total

    @total.setter
    def total(self, new_total):
        # Notify the listener first, then remember the value locally.
        self._signals.pbar_total.emit(new_total)
        self._total = new_total

    def update(self, steps):
        """Emit the number of steps by which the progress bar advances."""
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        """Emit the new description. Extra keyword arguments are accepted but ignored."""
        self._signals.pbar_description.emit(desc)
def _count_parameters(model_parameters):
    """Print the number of trainable parameters in millions.

    Args:
        model_parameters: An iterable of parameters. Only entries with
            `requires_grad` set contribute their `numel()` to the count.
    """
    n_trainable = 0
    for param in model_parameters:
        if param.requires_grad:
            n_trainable += param.numel()

    n_millions = n_trainable / 1e6
    print(f"The number of trainable parameters for the provided model is {round(n_millions, 2)}M")
@contextmanager
def _filter_warnings(ignore_warnings):
    """Context manager that optionally silences all warnings raised in its body.

    Args:
        ignore_warnings: If truthy, warnings raised inside the `with` block are ignored
            and the previous warning filters are restored on exit.
            Otherwise the warning configuration is left untouched.
    """
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        yield
def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs
            without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder
            for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely:
            image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )
        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

        # Create the UNETR decoder (if train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            joint_model_params = [params for params in model.parameters()]  # sam parameters
            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
                # The encoder parameters are skipped here: the UNETR is built on SAM's image encoder,
                # whose parameters are already part of the list.
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            optimizer = optimizer_class(joint_model_params, lr=lr)

        else:
            optimizer = optimizer_class(model.parameters(), lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        # An explicit number of iterations takes precedence over the number of epochs.
        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        # Split the rounded runtime into hours / minutes / seconds.
        # Using divmod keeps minutes and seconds within [0, 59], also for runs longer than one hour.
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    """Prepend a singleton axis to a 2D patch shape if the underlying data requires it.

    The dimensionality of the data is read from the first raw path. A 2D patch shape is
    extended to `(1,) + patch_shape` for 3D data without channels and for 4D data with channels.
    In all other cases the patch shape is returned unchanged.

    Args:
        patch_shape: The patch shape as a tuple.
        raw_paths: A single path or a sequence of paths to the raw data.
        raw_key: The key for the raw data. If None, the path is read as an image file.
        with_channels: Whether the raw data has a channel axis.

    Returns:
        The (possibly extended) patch shape.
    """
    path = raw_paths if isinstance(raw_paths, (str, os.PathLike)) else raw_paths[0]
    assert isinstance(path, (str, os.PathLike))

    # Determine the dimensionality of the underlying data.
    if raw_key is not None:
        try:
            # First try to open the data with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:
            # This may fail for images in a folder with different sizes.
            # In that case the key is used as glob pattern and one of the images is read.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim
    else:
        # Without a key the path is assumed to point to an image file.
        ndim = imageio.imread(path).ndim

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape

    # 3D data must come without channels and 4D data with channels to get the singleton axis.
    needs_singleton = ndim in (3, 4) and len(patch_shape) == 2 and bool(with_channels) == (ndim == 4)
    if needs_singleton:
        return (1,) + patch_shape
    return patch_shape
def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    is_seg_dataset: Optional[bool] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has RGB channels.
        sampler: A sampler to reject batches according to a given criterion.
            If not given, a `MinInstanceSampler` is used.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset. If not given, it is derived from
            the length of a default loader, with a lower bound of 100 (training) or 5 (validation).
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        is_seg_dataset: Whether the dataset is built 'from torch_em.data import SegmentationDataset'
            or 'from torch_em.data import ImageCollectionDataset'
        kwargs: Additional keyword arguments passed to `torch_em.default_segmentation_dataset`.

    Returns:
        The dataset.
    """
    # Fall back to the 8bit conversion if no raw transformation was passed.
    raw_transform = require_8bit if raw_transform is None else raw_transform

    # The label transformation depends on whether the segmentation decoder is trained.
    label_transforms = torch_em.transform.label
    if not with_segmentation_decoder:
        label_transform = label_transforms.MinSizeLabelTransform(min_size=min_size)
    else:
        label_transform = label_transforms.PerObjectDistanceTransform(
            distances=True, boundary_distances=True, directed_distances=False,
            foreground=True, instances=True, min_size=min_size,
        )

    # Use the default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Add a singleton to the patch shape if the data requires it.
    patch_shape = _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels)

    data_location = (raw_paths, raw_key, label_paths, label_key)

    # Derive the number of samples per epoch and enforce a minimum.
    if n_samples is None:
        probe_loader = torch_em.default_segmentation_loader(
            *data_location, batch_size=1,
            patch_shape=patch_shape, ndim=2, is_seg_dataset=is_seg_dataset,
        )
        min_n_samples = 100 if is_train else 5
        n_samples = max(len(probe_loader), min_n_samples)

    dataset = torch_em.default_segmentation_dataset(
        *data_location,
        patch_shape=patch_shape,
        raw_transform=raw_transform, label_transform=label_transform,
        with_channels=with_channels, ndim=2,
        sampler=sampler, n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )

    if max_sampling_attempts is None:
        return dataset

    # For a concatenated dataset the attribute is set on each of its member datasets.
    is_concat = isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset)
    for ds in (dataset.datasets if is_concat else [dataset]):
        ds.max_sampling_attempts = max_sampling_attempts

    return dataset
def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    The keyword arguments are divided by `split_kwargs` with respect to `default_sam_dataset`:
    the first part is used to create the dataset, the second part is passed on to
    `torch_em.segmentation.get_data_loader`.

    Args:
        kwargs: Keyword arguments for the dataset and the data loader.

    Returns:
        The data loader.
    """
    dataset_kwargs, data_loader_kwargs = split_kwargs(default_sam_dataset, **kwargs)
    dataset = default_sam_dataset(**dataset_kwargs)
    return torch_em.segmentation.get_data_loader(dataset, **data_loader_kwargs)
CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources.
"""
def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder
            for automatic instance segmentation.
        model_type: Over-ride the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default sam model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
            They take precedence over the settings of the configuration.

    Raises:
        ValueError: If the configuration is not listed in `CONFIGURATIONS`.
    """
    if configuration not in CONFIGURATIONS:
        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")

    # Work on a copy: popping from / updating the entry of the module-level dictionary
    # would break subsequent calls with the same configuration.
    train_kwargs = dict(CONFIGURATIONS[configuration])

    expected_model_type = train_kwargs.pop("model_type")
    if model_type is None:
        model_type = expected_model_type
    elif model_type[:5] != expected_model_type:
        # Only the first five characters (e.g. 'vit_b') are compared with the configuration's model type.
        warnings.warn("You have specified a different model type.")

    train_kwargs.update(kwargs)
    train_sam(
        name=name, train_loader=train_loader, val_loader=val_loader,
        checkpoint_path=checkpoint_path, with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type, **train_kwargs
    )
FilePath = typing.Union[str, os.PathLike]
196    Args:
197        name: The name of the model to be trained.
198            The checkpoint and logs will have this name.
199        model_type: The type of the SAM model.
200        train_loader: The dataloader for training.
201        val_loader: The dataloader for validation.
202        n_epochs: The number of epochs to train for.
203        early_stopping: Enable early stopping after this number of epochs
204            without improvement.
205        n_objects_per_batch: The number of objects per batch used to compute
206            the loss for interactive segmentation. If None all objects will be used,
207            if given objects will be randomly sub-sampled.
208        checkpoint_path: Path to checkpoint for initializing the SAM model.
209        with_segmentation_decoder: Whether to train additional UNETR decoder
210            for automatic instance segmentation.
211        freeze: Specify parts of the model that should be frozen, namely:
212            image_encoder, prompt_encoder and mask_decoder
213            By default nothing is frozen and the full model is updated.
214        device: The device to use for training.
215        lr: The learning rate.
216        n_sub_iteration: The number of iterative prompts per training iteration.
217        save_root: Optional root directory for saving the checkpoints and logs.
218            If not given the current working directory is used.
219        mask_prob: The probability for using a mask as input in a given training sub-iteration.
220        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
221        scheduler_class: The learning rate scheduler to update the learning rate.
222            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
223        scheduler_kwargs: The learning rate scheduler parameters.
224            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
225        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
226        pbar_signals: Controls for napari progress bar.
227        optimizer_class: The optimizer class.
228            By default, torch.optim.AdamW is used.
229        peft_kwargs: Keyword arguments for the PEFT wrapper class.
230        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
231            By default, 50 batches of labels are verified from the dataloaders.
232        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
233        ignore_warnings: Whether to ignore raised warnings.
234    """
Run training for a SAM model.

  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement.
  • n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None all objects will be used, if given objects will be randomly sub-sampled.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
  • device: The device to use for training.
  • lr: The learning rate.
  • n_sub_iteration: The number of iterative prompts per training iteration.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • mask_prob: The probability for using a mask as input in a given training sub-iteration.
  • n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
  • model_kwargs: Additional keyword arguments for the util.get_sam_model.
  • ignore_warnings: Whether to ignore raised warnings.
398    Args:
399        raw_paths: The path(s) to the image data used for training.
400            Can either be multiple 2D images or volumetric data.
401        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
402            or a glob pattern for selecting multiple files.
403        label_paths: The path(s) to the label data used for training.
404            Can either be multiple 2D images or volumetric data.
405        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
406            or a glob pattern for selecting multiple files.
407        patch_shape: The shape for training patches.
408        with_segmentation_decoder: Whether to train with additional segmentation decoder.
409        with_channels: Whether the image data has RGB channels.
410        sampler: A sampler to reject batches according to a given criterion.
411        raw_transform: Transformation applied to the image data.
412            If not given the data will be cast to 8bit.
413        n_samples: The number of samples for this dataset.
414        is_train: Whether this dataset is used for training or validation.
415        min_size: Minimal object size. Smaller objects will be filtered.
416        max_sampling_attempts: Number of sampling attempts to make from a dataset.
417        is_seg_dataset: Whether the dataset is built 'from torch_em.data import SegmentationDataset'
418            or 'from torch_em.data import ImageCollectionDataset'
420    Returns:
421        The dataset.
Create a PyTorch Dataset for training a SAM model.

  • raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
  • raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
  • label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • patch_shape: The shape for training patches.
  • with_segmentation_decoder: Whether to train with additional segmentation decoder.
  • with_channels: Whether the image data has RGB channels.
  • sampler: A sampler to reject batches according to a given criterion.
  • raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
  • n_samples: The number of samples for this dataset.
  • is_train: Whether this dataset is used for training or validation.
  • min_size: Minimal object size. Smaller objects will be filtered.
  • max_sampling_attempts: Number of sampling attempts to make from a dataset.
  • is_seg_dataset: Whether the dataset is built 'from torch_em.data import SegmentationDataset' or 'from torch_em.data import ImageCollectionDataset'

The dataset.

CONFIGURATIONS = {'Minimal': {'model_type': 'vit_t', 'n_objects_per_batch': 4, 'n_sub_iteration': 4}, 'CPU': {'model_type': 'vit_b', 'n_objects_per_batch': 10}, 'gtx1080': {'model_type': 'vit_t', 'n_objects_per_batch': 5}, 'rtx5000': {'model_type': 'vit_b', 'n_objects_per_batch': 10}, 'V100': {'model_type': 'vit_b'}, 'A100': {'model_type': 'vit_h'}}

Best training configurations for given hardware resources.

506    Selects the best training settings for the given configuration.
507    The available configurations are listed in `CONFIGURATIONS`.
509    Args:
510        name: The name of the model to be trained.
511            The checkpoint and logs will have this name.
512        configuration: The configuration (= name of hardware resource).
513        train_loader: The dataloader for training.
514        val_loader: The dataloader for validation.
515        checkpoint_path: Path to checkpoint for initializing the SAM model.
516        with_segmentation_decoder: Whether to train additional UNETR decoder
517            for automatic instance segmentation.
518        model_type: Over-ride the default model type.
519            This can be used to use one of the micro_sam models as starting point
520            instead of a default sam model.
521        kwargs: Additional keyword parameters that will be passed to `train_sam`.
Run training for a SAM model with the configuration for a given hardware resource.

Selects the best training settings for the given configuration. The available configurations are listed in CONFIGURATIONS.

  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • configuration: The configuration (= name of hardware resource).
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
  • model_type: Over-ride the default model type. This can be used to use one of the micro_sam models as starting point instead of a default sam model.
  • kwargs: Additional keyword parameters that will be passed to train_sam.