
  1import os
  2import time
  3import warnings
  4from glob import glob
  5from tqdm import tqdm
  6from contextlib import contextmanager, nullcontext
  7from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  9import imageio.v3 as imageio
 11import torch
 12from torch.optim import Optimizer
 13from torch.utils.data import random_split
 14from torch.utils.data import DataLoader, Dataset
 15from torch.optim.lr_scheduler import _LRScheduler
 17import torch_em
 18from torch_em.util import load_data
 19from torch_em.data.datasets.util import split_kwargs
 21from elf.io import open_file
 24    from qtpy.QtCore import QObject
 25except Exception:
 26    QObject = Any
 28from . import sam_trainer as trainers
 29from ..instance_segmentation import get_unetr
 30from . import joint_sam_trainer as joint_trainers
 31from ..util import get_device, get_model_names, export_custom_sam_model
 32from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform
 35FilePath = Union[str, os.PathLike]
 38def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
 39    x, _ = next(iter(loader))
 41    # Raw data: check that we have 1 or 3 channels.
 42    n_channels = x.shape[1]
 43    if n_channels not in (1, 3):
 44        raise ValueError(
 45            "Invalid number of channels for the input data from the data loader. "
 46            f"Expect 1 or 3 channels, got {n_channels}."
 47        )
 49    # Raw data: check that it is between [0, 255]
 50    minval, maxval = x.min(), x.max()
 51    if minval < 0 or minval > 255:
 52        raise ValueError(
 53            "Invalid data range for the input data from the data loader. "
 54            f"The input has to be in range [0, 255], but got minimum value {minval}."
 55        )
 56    if maxval < 1 or maxval > 255:
 57        raise ValueError(
 58            "Invalid data range for the input data from the data loader. "
 59            f"The input has to be in range [0, 255], but got maximum value {maxval}."
 60        )
 62    # Target data: the check depends on whether we train with or without decoder.
 63    # NOTE: Verification step to check whether all labels from dataloader are valid (i.e. have atleast one instance).
 65    def _check_instance_channel(instance_channel):
 66        unique_vals = torch.unique(instance_channel)
 67        if (unique_vals < 0).any():
 68            raise ValueError(
 69                "The target channel with the instance segmentation must not have negative values."
 70            )
 71        if len(unique_vals) == 1:
 72            raise ValueError(
 73                "The target channel with the instance segmentation must have at least one instance."
 74            )
 75        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
 76            raise ValueError(
 77                "All values in the target channel with the instance segmentation must be integer."
 78            )
 80    counter = 0
 81    name = "" if name is None else f"'{name}'"
 82    for x, y in tqdm(
 83        loader,
 84        desc=f"Verifying labels in {name} dataloader",
 85        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
 86    ):
 87        n_channels_y = y.shape[1]
 88        if with_segmentation_decoder:
 89            if n_channels_y != 4:
 90                raise ValueError(
 91                    "Invalid number of channels in the target data from the data loader. "
 92                    "Expect 4 channel for training with an instance segmentation decoder, "
 93                    f"but got {n_channels_y} channels."
 94                )
 95            # Check instance channel per sample in a batch
 96            for per_y_sample in y:
 97                _check_instance_channel(per_y_sample[0])
 99            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
100            if targets_min < 0 or targets_min > 1:
101                raise ValueError(
102                    "Invalid value range in the target data from the value loader. "
103                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
104                    f"to be in range [0.0, 1.0], but got min {targets_min}"
105                )
106            if targets_max < 0 or targets_max > 1:
107                raise ValueError(
108                    "Invalid value range in the target data from the value loader. "
109                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
110                    f"to be in range [0.0, 1.0], but got max {targets_max}"
111                )
113        else:
114            if n_channels_y != 1:
115                raise ValueError(
116                    "Invalid number of channels in the target data from the data loader. "
117                    "Expect 1 channel for training without an instance segmentation decoder, "
118                    f"but got {n_channels_y} channels."
119                )
120            # Check instance channel per sample in a batch
121            for per_y_sample in y:
122                _check_instance_channel(per_y_sample)
124        counter += 1
125        if verify_n_labels_in_loader is not None and counter > verify_n_labels_in_loader:
126            break
129# Make the progress bar callbacks compatible with a tqdm progress bar interface.
130class _ProgressBarWrapper:
131    def __init__(self, signals):
132        self._signals = signals
133        self._total = None
135    @property
136    def total(self):
137        return self._total
139    @total.setter
140    def total(self, value):
141        self._signals.pbar_total.emit(value)
142        self._total = value
144    def update(self, steps):
145        self._signals.pbar_update.emit(steps)
147    def set_description(self, desc, **kwargs):
148        self._signals.pbar_description.emit(desc)
152def _filter_warnings(ignore_warnings):
153    if ignore_warnings:
154        with warnings.catch_warnings():
155            warnings.simplefilter("ignore")
156            yield
157    else:
158        with nullcontext():
159            yield
162def _count_parameters(model_parameters):
163    params = sum(p.numel() for p in model_parameters if p.requires_grad)
164    params = params / 1e6
165    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")
def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Both data loaders are validated first. Then the model, optimizer, scheduler and trainer
    are created, the trainer is fit and the elapsed time is printed.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

        # Create the UNETR decoder (if train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            # The UNETR parameters starting with 'encoder' are skipped: the image encoder
            # is passed in from SAM, so its parameters are already part of the list.
            model_params = list(model.parameters())  # sam parameters
            model_params.extend(  # unetr's decoder parameters
                params for param_name, params in unetr.named_parameters()
                if not param_name.startswith("encoder")
            )
        else:
            model_params = model.parameters()

        optimizer = optimizer_class(model_params, lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The arguments that are shared by both trainer classes.
        trainer_kwargs = dict(
            name=name,
            save_root=save_root,
            train_loader=train_loader,
            val_loader=val_loader,
            model=model,
            optimizer=optimizer,
            device=device,
            lr_scheduler=scheduler,
            log_image_interval=100,
            mixed_precision=True,
            convert_inputs=convert_inputs,
            n_objects_per_batch=n_objects_per_batch,
            n_sub_iteration=n_sub_iteration,
            compile_model=False,
            early_stopping=early_stopping,
            mask_prob=mask_prob,
        )

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                logger=joint_trainers.JointSamLogger,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                **trainer_kwargs,
            )
        else:
            trainer = trainers.SamTrainer(logger=trainers.SamLogger, **trainer_kwargs)

        # The number of iterations takes precedence over the number of epochs.
        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        # Split the rounded run time into hours, minutes within the hour and seconds within
        # the minute. (Taking 't_run // 60' directly would report the total number of minutes.)
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    """Prepend a singleton axis to a 2d patch shape if the underlying data requires it.

    The dimensionality of the data is determined from the first raw path.

    Args:
        patch_shape: The patch shape. It is converted to a tuple if it is not one already.
        raw_paths: A single path or a sequence of paths to the raw data.
        raw_key: The key for accessing the data. If None, the path is read as an image file.
        with_channels: Whether the data has a channel axis.

    Returns:
        The patch shape as tuple: unchanged for 2d data, with a leading singleton for a 2d patch shape
        and 3d data without channels or 4d data with channels, and unchanged otherwise.
    """
    path_types = (str, os.PathLike)
    path = raw_paths if isinstance(raw_paths, path_types) else raw_paths[0]
    assert isinstance(path, path_types)

    # Determine the dimensionality of the underlying data.
    if raw_key is None:
        # Without a key we assume that the path points to an image file.
        ndim = imageio.imread(path).ndim
    else:
        try:
            # With a key we first try to open the data with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:
            # This may fail for images in a folder with different sizes.
            # In that case we read one of the images matching the key.
            matching_image = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(matching_image).ndim

    patch_shape = patch_shape if isinstance(patch_shape, tuple) else tuple(patch_shape)

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape

    # A 2d patch shape needs a leading singleton for volumetric data,
    # i.e. 3d data without channels or 4d data with channels.
    is_volumetric = (ndim == 3 and not with_channels) or (ndim == 4 and with_channels)
    if len(patch_shape) == 2 and is_volumetric:
        return (1,) + patch_shape
    return patch_shape
def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
            A single path can be given as string or as path-like object.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
            A single path can be given as string or as path-like object.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # A single path may be given as string or as path-like object (see 'FilePath').
    # Both have to be treated as one path and must not be indexed like a list of paths.
    single_path_types = (str, os.PathLike)

    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
    is_seg_dataset = kwargs.pop("is_seg_dataset", None)

    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
    # Get valid raw paths to make checks possible.
    first_raw_path = raw_paths if isinstance(raw_paths, single_path_types) else raw_paths[0]
    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
        rpath = glob(os.path.join(first_raw_path, raw_key))[0]
    else:  # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath.
        rpath = first_raw_path

    # Load one of the raw inputs to validate whether it is RGB or not.
    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
    if test_raw_inputs.ndim == 3:
        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
            # We need to provide a list of inputs to 'ImageCollectionDataset'.
            raw_paths = [raw_paths] if isinstance(raw_paths, single_path_types) else raw_paths
            label_paths = [label_paths] if isinstance(label_paths, single_path_types) else label_paths

            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
            with_channels = False if with_channels is None else with_channels

        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is a RGB image and has 3 channels first.
            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
            with_channels = True if with_channels is None else with_channels

    # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'
    # Otherwise, let the user make the choice as priority, else set this to our suggested default.
    with_channels = False if with_channels is None else with_channels

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=True,
            min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        # A loader with batch size 1 is created only to derive the number of samples from its length.
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            is_seg_dataset=is_seg_dataset,
            raw_transform=raw_transform,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )

    # Set the sampling attempts on each of the underlying datasets for a concatenated dataset.
    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset
def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    The keyword arguments are split in two steps: first the arguments of `default_sam_dataset` are
    separated, then the arguments of `torch_em.default_segmentation_dataset`. Both groups are used
    to create the dataset; the remaining arguments are passed on to the data loader.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    sam_ds_kwargs, remaining_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # Users may also pass additional parameters that are supported by
    # `torch_em.default_segmentation_dataset` to get their desired segmentation dataset.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **remaining_kwargs)

    dataset_kwargs = dict(sam_ds_kwargs)
    dataset_kwargs.update(extra_ds_kwargs)

    dataset = default_sam_dataset(**dataset_kwargs)
    return torch_em.segmentation.get_data_loader(dataset, **loader_kwargs)
545    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
546    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
547    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
548    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
549    "V100": {"model_type": "vit_b"},
550    "A100": {"model_type": "vit_h"},
554def _find_best_configuration():
555    if torch.cuda.is_available():
557        # Check how much memory we have and select the best matching GPU
558        # for the available VRAM size.
559        _, vram = torch.cuda.mem_get_info()
560        vram = vram / 1e9  # in GB
562        # Maybe we can get more configurations in the future.
563        if vram > 80:  # More than 80 GB: use the A100 configurations.
564            return "A100"
565        elif vram > 30:  # More than 30 GB: use the V100 configurations.
566            return "V100"
567        elif vram > 14:  # More than 14 GB: use the RTX5000 configurations.
568            return "rtx5000"
569        else:  # Otherwise: not enough memory to train on the GPU, use CPU instead.
570            return "CPU"
571    else:
572        return "CPU"
575"""Best training configurations for given hardware resources.
def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder
            for automatic instance segmentation.
        model_type: Over-ride the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default sam model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
            They take precedence over the settings of the configuration.

    Raises:
        ValueError: If the configuration is not one of the keys of `CONFIGURATIONS`.
    """
    if configuration not in CONFIGURATIONS:
        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")

    # Work on a copy of the settings: they are modified below, and modifying the entry of
    # 'CONFIGURATIONS' in place would remove its 'model_type' and leak 'kwargs' into later calls.
    train_kwargs = dict(CONFIGURATIONS[configuration])

    expected_model_type = train_kwargs.pop("model_type")
    if model_type is None:
        model_type = expected_model_type
    # Only the first 5 characters are compared with the model type of the configuration.
    elif model_type[:5] != expected_model_type:
        warnings.warn("You have specified a different model type.")

    train_kwargs.update(**kwargs)
    train_sam(
        name=name,
        train_loader=train_loader,
        val_loader=val_loader,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type,
        **train_kwargs
    )
632def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):
634    # Whether the model is stored in the current working directory or in another location.
635    if save_root is None:
636        save_root = os.getcwd()  # Map this to current working directory, if not specified by the user.
638    # Get the 'best' model checkpoint ready for export.
639    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
640    if not os.path.exists(best_checkpoint):
641        raise FileNotFoundError(f"The trained model not found at the expected location: '{best_checkpoint}'.")
643    # Export the model if an output path has been given.
644    if output_path:
646        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
647        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
648            export_custom_sam_model(
649                checkpoint_path=best_checkpoint,
650                model_type=model_type[:5],
651                save_path=output_path,
652                with_segmentation_decoder=with_segmentation_decoder,
653            )
655        # Otherwise we export it as bioimage.io model.
656        else:
657            from micro_sam.bioimageio import export_sam_model
659            # Load image and corresponding labels from the val loader.
660            with torch.no_grad():
661                image_data, label_data = next(iter(val_loader))
662                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()
664                # Select the first channel of the label image if we have a channel axis, i.e. contains the labels
665                if label_data.ndim == 3:
666                    label_data = label_data[0]  # Gets the channel with instances.
667                assert image_data.shape == label_data.shape
668                label_data = label_data.astype("uint32")
670                export_sam_model(
671                    image=image_data,
672                    label_image=label_data,
673                    model_type=model_type[:5],
674                    name=checkpoint_name,
675                    output_path=output_path,
676                    checkpoint_path=best_checkpoint,
677                )
679        # The final path where the model has been stored.
680        final_path = output_path
682    else:  # If no exports have been made, inform the user about the best checkpoint.
683        final_path = best_checkpoint
685    return final_path
def main():
    """@private"""
    # Command line entry point: parses the arguments, builds the train / val dataloaders,
    # runs 'train_sam_for_configuration' and finally hands over to '_export_helper'.
    import argparse

    # Keep the valid choices as lists: they are needed for exact membership checks below.
    # The comma-separated strings are only used to render the help texts.
    valid_models = list(get_model_names())
    available_models = ", ".join(valid_models)

    valid_configurations = list(CONFIGURATIONS.keys())
    available_configurations = ", ".join(valid_configurations)

    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")

    # Images and labels for training.
    parser.add_argument(
        "--images", required=True, type=str, nargs="*",
        help="Filepath to images or the directory where the image data is stored."
    )
    parser.add_argument(
        "--labels", required=True, type=str, nargs="*",
        help="Filepath to ground-truth labels or the directory where the label data is stored."
    )
    parser.add_argument(
        "--image_key", type=str, default=None,
        help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file. "
    )
    parser.add_argument(
        "--label_key", type=str, default=None,
        help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file. "
    )

    # Images and labels for validation.
    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
    # Users can choose to have their explicit validation set via this feature as well.
    parser.add_argument(
        "--val_images", type=str, nargs="*",
        help="Filepath to images for validation or the directory where the image data is stored."
    )
    parser.add_argument(
        "--val_labels", type=str, nargs="*",
        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
    )
    parser.add_argument(
        "--val_image_key", type=str, default=None,
        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--val_label_key", type=str, default=None,
        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
    )

    # Other necessary stuff for training.
    parser.add_argument(
        "--configuration", type=str, default=_find_best_configuration(),
        help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations}."
    )
    parser.add_argument(
        "--segmentation_decoder", type=str, default="instances",  # TODO: in future, we can extend this to semantic seg.
        help="Whether to finetune Segment Anything Model with additional segmentation decoder for desired targets. "
        "By default, it trains with the additional segmentation decoder for instance segmentation."
    )

    # Optional advanced settings a user can opt to change the values for.
    parser.add_argument(
        "-d", "--device", type=str, default=None,
        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs="*", default=(512, 512),
        help="The choice of patch shape for training Segment Anything."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=None,
        help=f"The Segment Anything Model that will be used for finetuning, one of {available_models}."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded for finetuning."
    )
    parser.add_argument(
        "-s", "--save_root", type=str, default=None,
        help="The directory where the trained models and corresponding logs will be stored. "
        "By default, there are stored in your current working directory."
    )
    parser.add_argument(
        "--trained_model_name", type=str, default="sam_model",
        help="The custom name of trained model. Allows users to have several trained models under the same 'save_root'."
    )
    parser.add_argument(
        "--output_path", type=str, default=None,
        help="The directory (eg. '/path/to/folder') or filepath (eg. '/path/to/model.pt') to export the trained model."
    )
    parser.add_argument(
        "--n_epochs", type=int, default=100,
        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
    )
    parser.add_argument(
        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The choice of batch size for training the Segment Anything Model. By default, trains on batch size 1."
    )
    parser.add_argument(
        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images "
        "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
    )

    args = parser.parse_args()

    # 1. Get all necessary stuff for training.
    checkpoint_name = args.trained_model_name
    config = args.configuration
    model_type = args.model_type
    checkpoint_path = args.checkpoint_path
    batch_size = args.batch_size
    patch_shape = args.patch_shape
    epochs = args.n_epochs
    num_workers = args.num_workers
    device = args.device
    save_root = args.save_root
    output_path = args.output_path
    with_segmentation_decoder = (args.segmentation_decoder == "instances")

    # Validate the model and configuration before any data is touched, so that invalid choices fail fast.
    # The checks run against the lists of names: testing against the joined help strings would be a
    # substring test and would accept values such as 'vit' or ','.
    if model_type is not None and model_type not in valid_models:
        raise ValueError(f"'{model_type}' is not a valid choice of model.")
    if config is not None and config not in valid_configurations:
        raise ValueError(f"'{config}' is not a valid choice of configuration.")

    if model_type is None:  # If user does not specify the model, we use the default model corresponding to the config.
        model_type = CONFIGURATIONS[config]["model_type"]

    # Get image paths and corresponding keys.
    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa

    # 2. Prepare the dataloaders.

    # If the user wants to preprocess the inputs, we allow the possibility to do so.
    _raw_transform = get_raw_transform(args.preprocess)

    # Get the dataset with files for training.
    dataset = default_sam_dataset(
        raw_paths=train_images,
        raw_key=train_image_key,
        label_paths=train_gt,
        label_key=train_gt_key,
        patch_shape=patch_shape,
        with_segmentation_decoder=with_segmentation_decoder,
        raw_transform=_raw_transform,
    )

    # If val images are not exclusively provided, we create a val split from the training data.
    if val_images is None:
        # Raise explicitly instead of using 'assert', which is stripped when python runs with '-O'.
        if val_gt is not None or val_image_key is not None or val_gt_key is not None:
            raise ValueError(
                "The arguments '--val_labels', '--val_image_key' and '--val_label_key' "
                "can only be used together with '--val_images'."
            )
        # Use 10% of the dataset - at least one image - for validation.
        n_val = max(1, int(0.1 * len(dataset)))
        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])

    else:  # If val images provided, we create a new dataset for it.
        train_dataset = dataset
        val_dataset = default_sam_dataset(
            raw_paths=val_images,
            raw_key=val_image_key,
            label_paths=val_gt,
            label_key=val_gt_key,
            patch_shape=patch_shape,
            with_segmentation_decoder=with_segmentation_decoder,
            raw_transform=_raw_transform,
        )

    # Get the dataloaders from the datasets.
    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    # 3. Train the Segment Anything Model.
    train_sam_for_configuration(
        name=checkpoint_name,
        configuration=config,
        model_type=model_type,
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=epochs,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        freeze=None,  # TODO: Allow for PEFT.
        device=device,
        save_root=save_root,
        peft_kwargs=None,  # TODO: Allow for PEFT.
    )

    # 4. Export the model, if desired by the user
    final_path = _export_helper(
        save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader
    )

    print(f"Training has finished. The trained model is saved at {final_path}.")
# Type alias for a single path on the filesystem, given either as a string or as a path-like object.
FilePath = Union[str, os.PathLike]
def train_sam( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, n_objects_per_batch: Optional[int] = 25, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, n_sub_iteration: int = 8, save_root: Union[os.PathLike, str, NoneType] = None, mask_prob: float = 0.5, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt5.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, box_distortion_factor: Optional[float] = 0.025, **model_kwargs) -> None:
def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for iterative segmentation. If None all objects will be used,
            if given objects will be randomly sub-sampled.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training.
        lr: The learning rate.
        n_sub_iteration: The number of iterative prompts per training iteration.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class.
            By default, torch.optim.AdamW is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        # Validate both loaders up-front, before the model is created.
        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

        # Create the UNETR decoder (if train with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            # The UNETR is built on SAM's image encoder, so its "encoder" parameters are skipped here:
            # they are already part of the SAM parameters.
            model_params = list(model.parameters())  # sam parameters
            model_params.extend(  # unetr's decoder parameters
                params for param_name, params in unetr.named_parameters() if not param_name.startswith("encoder")
            )
        else:
            model_params = model.parameters()

        optimizer = optimizer_class(model_params, lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            # The same loss object is used as loss and as validation metric for the instance segmentation.
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        # The number of iterations takes precedence over the number of epochs.
        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        # Split the rounded run time into hh:mm:ss. Using 'divmod' keeps minutes and seconds within [0, 59],
        # also for runs that take longer than one hour.
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

Run training for a SAM model.

  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement.
  • n_objects_per_batch: The number of objects per batch used to compute the loss for iterative segmentation. If None all objects will be used, if given objects will be randomly sub-sampled.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
  • device: The device to use for training.
  • lr: The learning rate.
  • n_sub_iteration: The number of iterative prompts per training iteration.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • mask_prob: The probability for using a mask as input in a given training sub-iteration.
  • n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings.
  • verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
  • box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
  • model_kwargs: Additional keyword arguments for the util.get_sam_model.
def default_sam_dataset( raw_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: Optional[bool] = None, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, is_train: bool = True, min_size: int = 25, max_sampling_attempts: Optional[int] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
    is_seg_dataset = kwargs.pop("is_seg_dataset", None)

    # A single path may be passed as 'str' or as 'os.PathLike' (e.g. 'pathlib.Path').
    # Normalize it to 'str', so that the single-path checks below apply to both
    # and a path-like object is not mistaken for a list of paths.
    if isinstance(raw_paths, os.PathLike):
        raw_paths = os.fspath(raw_paths)
    if isinstance(label_paths, os.PathLike):
        label_paths = os.fspath(label_paths)

    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
    # Get valid raw paths to make checks possible.
    first_raw_path = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
    has_wildcard = bool(raw_key) and "*" in raw_key
    if has_wildcard:  # Use the wildcard pattern to find the filepath to only one image.
        rpath = glob(os.path.join(first_raw_path, raw_key))[0]
    else:  # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath.
        rpath = first_raw_path

    # Load one of the raw inputs to validate whether it is RGB or not.
    # A wildcard key only selects files, so it must not be passed on as internal key.
    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and not has_wildcard else None)
    if test_raw_inputs.ndim == 3:
        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
            # We need to provide a list of inputs to 'ImageCollectionDataset'.
            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths

            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
            with_channels = False if with_channels is None else with_channels

        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is a RGB image and has 3 channels first.
            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
            with_channels = True if with_channels is None else with_channels

    # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'
    # Otherwise, let the user make the choice as priority, else set this to our suggested default.
    with_channels = False if with_channels is None else with_channels

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=True,
            min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        # The length of a temporary loader with batch size 1 is used as the initial sample count.
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            is_seg_dataset=is_seg_dataset,
            raw_transform=raw_transform,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        # A concatenated dataset does not forward the attribute, so it is set on each sub-dataset.
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset

Create a PyTorch Dataset for training a SAM model.

  • raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
  • raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
  • label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • patch_shape: The shape for training patches.
  • with_segmentation_decoder: Whether to train with additional segmentation decoder.
  • with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
  • sampler: A sampler to reject batches according to a given criterion.
  • raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
  • n_samples: The number of samples for this dataset.
  • is_train: Whether this dataset is used for training or validation.
  • min_size: Minimal object size. Smaller objects will be filtered.
  • max_sampling_attempts: Number of sampling attempts to make from a dataset.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.

The segmentation dataset.

def default_sam_loader(**kwargs) -> torch.utils.data.dataloader.DataLoader:
def default_sam_loader(**kwargs) -> DataLoader:
    """Build a PyTorch DataLoader for training a SAM model.

    The keyword arguments are split into three groups: the ones accepted by `default_sam_dataset`,
    the ones accepted by `torch_em.default_segmentation_dataset` and the remaining ones,
    which are passed to the data loader.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # First pick out everything that 'default_sam_dataset' accepts by name.
    dataset_kwargs, remaining_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # Of the remaining arguments, the ones known to 'torch_em.default_segmentation_dataset' are passed on
    # to the dataset as well; the rest is left for the data loader.
    segmentation_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **remaining_kwargs)

    merged_kwargs = dict(dataset_kwargs)
    merged_kwargs.update(segmentation_kwargs)

    return torch_em.segmentation.get_data_loader(default_sam_dataset(**merged_kwargs), **loader_kwargs)

Create a PyTorch DataLoader for training a SAM model.

  • kwargs: Keyword arguments for micro_sam.training.default_sam_dataset or for the PyTorch DataLoader.

The DataLoader.

# Training settings per configuration name. `train_sam_for_configuration` looks up the entry
# for the requested configuration and uses it as keyword arguments for the training.
CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
def train_sam_for_configuration( name: str, configuration: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, model_type: Optional[str] = None, **kwargs) -> None:
def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder
            for automatic instance segmentation.
        model_type: Override the default model type.
            This can be used to start from one of the micro_sam models
            instead of a default SAM model. A warning is issued if its first
            five characters differ from the model type of the configuration.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
            They take precedence over the settings of the configuration.

    Raises:
        ValueError: If `configuration` is not a key of `CONFIGURATIONS`.
    """
    if configuration not in CONFIGURATIONS:
        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")

    # Work on a copy of the preset. Popping "model_type" from (or updating) the
    # module-level dict itself would corrupt `CONFIGURATIONS`: a second call with
    # the same configuration would fail with a KeyError, and user overrides from
    # one call would leak into all later calls.
    train_kwargs = dict(CONFIGURATIONS[configuration])

    # "model_type" is passed to `train_sam` explicitly below, so it has to be
    # removed from the forwarded settings in either case.
    expected_model_type = train_kwargs.pop("model_type")
    if model_type is None:
        model_type = expected_model_type
    elif model_type[:5] != expected_model_type:
        # Only the first five characters are compared, so a model type that merely
        # extends the expected one with a suffix does not trigger the warning.
        warnings.warn("You have specified a different model type.")

    # User supplied settings override the ones from the preset.
    train_kwargs.update(kwargs)
    train_sam(
        name=name,
        train_loader=train_loader,
        val_loader=val_loader,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type,
        **train_kwargs
    )

Run training for a SAM model with the configuration for a given hardware resource.

Selects the best training settings for the given configuration. The available configurations are listed in CONFIGURATIONS.

  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • configuration: The configuration (= name of hardware resource).
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
  • model_type: Override the default model type. This can be used to start from one of the micro_sam models instead of a default SAM model.
  • kwargs: Additional keyword parameters that will be passed to train_sam.