micro_sam.training.training

import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import _LRScheduler

import torch_em
from torch_em.util import load_data
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from . import joint_sam_trainer as joint_trainers
from ..util import get_device, get_model_names, export_custom_sam_model
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform


FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is in the range [0, 255].
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without the decoder.
    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channels for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel for each sample in the batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the last 3 target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}."
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the last 3 target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}."
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel for each sample in the batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        counter += 1
        if verify_n_labels_in_loader is not None and counter > verify_n_labels_in_loader:
            break


# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield


def _count_parameters(model_parameters):
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is ~{round(params, 2)}M.")


def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    overwrite_training: bool = True,
    **model_kwargs,
) -> None:
197    """Run training for a SAM model.
198
199    Args:
200        name: The name of the model to be trained. The checkpoint and logs will have this name.
201        model_type: The type of the SAM model.
202        train_loader: The dataloader for training.
203        val_loader: The dataloader for validation.
204        n_epochs: The number of epochs to train for.
205        early_stopping: Enable early stopping after this number of epochs without improvement.
206        n_objects_per_batch: The number of objects per batch used to compute
207            the loss for interative segmentation. If None all objects will be used,
208            if given objects will be randomly sub-sampled.
209        checkpoint_path: Path to checkpoint for initializing the SAM model.
210        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
211        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
212            By default nothing is frozen and the full model is updated.
213        device: The device to use for training.
214        lr: The learning rate.
215        n_sub_iteration: The number of iterative prompts per training iteration.
216        save_root: Optional root directory for saving the checkpoints and logs.
217            If not given the current working directory is used.
218        mask_prob: The probability for using a mask as input in a given training sub-iteration.
219        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
220        scheduler_class: The learning rate scheduler to update the learning rate.
221            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
222        scheduler_kwargs: The learning rate scheduler parameters.
223            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
224        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
225        pbar_signals: Controls for napari progress bar.
226        optimizer_class: The optimizer class.
227            By default, torch.optim.AdamW is used.
228        peft_kwargs: Keyword arguments for the PEFT wrapper class.
229        ignore_warnings: Whether to ignore raised warnings.
230        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
231            By default, 50 batches of labels are verified from the dataloaders.
232        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
233        overwrite_training: Whether to overwrite the trained model stored at the same location.
234            By default, overwrites the trained model at each run.
235            If set to 'False', it will avoid retraining the model if the previous run was completed.
236        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
237    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()

        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable Segment Anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

        # Create the UNETR decoder (if training with it) and the optimizer.
        if with_segmentation_decoder:

            # Get the UNETR.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder,
                decoder_state=state.get("decoder_state", None),
                device=device,
            )

            # Get the parameters for SAM and the decoder from UNETR.
            joint_model_params = list(model.parameters())  # The SAM parameters.
            for param_name, params in unetr.named_parameters():  # The UNETR decoder parameters.
                if not param_name.startswith("encoder"):
                    joint_model_params.append(params)

            model_params = joint_model_params
        else:
            model_params = model.parameters()

        optimizer = optimizer_class(model_params, lr=lr)

        if scheduler_kwargs is None:
            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        if n_iterations is None:
            trainer_fit_params = {"epochs": n_epochs}
        else:
            trainer_fit_params = {"iterations": n_iterations}

        if save_every_kth_epoch is not None:
            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

        if pbar_signals is not None:
            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
            trainer_fit_params["progress"] = progress_bar_wrapper

        # Avoid overwriting a trained model, if desired by the user.
        trainer_fit_params["overwrite_training"] = overwrite_training

        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)
        seconds = int(round(t_run % 60))
        print(f"Training took {t_run:.1f} seconds (= {hours:02}:{minutes:02}:{seconds:02} hours).")


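# Editorial addition, not part of the original module: a minimal usage sketch
# for `train_sam`. The loader arguments and the checkpoint name are
# hypothetical; such loaders can be built with `default_sam_loader` below.
def _example_train_sam(train_loader, val_loader):
    train_sam(
        name="sam_finetuned",  # Hypothetical checkpoint name.
        model_type="vit_b",
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=5,
        with_segmentation_decoder=True,
    )

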
def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
    assert isinstance(path, (str, os.PathLike))

    # Check the underlying data dimensionality.
    if raw_key is None:  # If no key is given then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file from key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

    if not isinstance(patch_shape, tuple):
        patch_shape = tuple(patch_shape)

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        return (1,) + patch_shape
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + patch_shape
    else:
        return patch_shape


def default_sam_dataset(
    raw_paths: Union[List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int, ...],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with an additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, the decision is made based on the inputs.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given, the data will be cast to 8-bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation.
        min_size: Minimal object size. Smaller objects will be filtered out.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
    is_seg_dataset = kwargs.pop("is_seg_dataset", None)

    # Check whether the raw inputs are RGB. If yes, use 'ImageCollectionDataset'.
    # Get a valid raw path to make the checks possible.
    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
    else:  # Otherwise, 'raw_key' is either None or an internal path for a container format supported by 'elf'.
        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]

    # Load one of the raw inputs to validate whether it is RGB or not.
    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
    if test_raw_inputs.ndim == 3:
        if test_raw_inputs.shape[-1] == 3:  # i.e. it is an RGB image with channels last.
            is_seg_dataset = False  # We use 'ImageCollectionDataset' in this case.
            # We need to provide a list of inputs to 'ImageCollectionDataset'.
            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths

            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
            with_channels = False if with_channels is None else with_channels

        elif test_raw_inputs.shape[0] == 3:  # i.e. it is an RGB image with channels first.
            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
            with_channels = True if with_channels is None else with_channels

    # If 'with_channels' is still unset, fall back to 'False', the default behavior
    # of 'default_segmentation_dataset'. An explicit user choice always takes priority.
    with_channels = False if with_channels is None else with_channels

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=True,
            min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # Set a default sampler if none was passed.
    if sampler is None:
        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            is_seg_dataset=is_seg_dataset,
            raw_transform=raw_transform,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset


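# Editorial addition, not part of the original module: a minimal sketch of
# creating the dataset directly (the loader wrapper below is the more common
# entry point). All paths and glob patterns are hypothetical placeholders.
def _example_default_sam_dataset():
    return default_sam_dataset(
        raw_paths="data/images",    # Hypothetical image folder.
        raw_key="*.tif",
        label_paths="data/labels",  # Hypothetical label folder.
        label_key="*.tif",
        patch_shape=(512, 512),
        with_segmentation_decoder=True,
    )

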
def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
    # which the users can provide to get their desired segmentation dataset.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

    ds = default_sam_dataset(**ds_kwargs)
    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)


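# Editorial addition, not part of the original module: a minimal sketch of
# building a training loader in one step. Keyword arguments not consumed by
# `default_sam_dataset` (e.g. `batch_size`, `shuffle`) are passed on to the
# PyTorch DataLoader. Paths and glob patterns are hypothetical placeholders.
def _example_default_sam_loader():
    return default_sam_loader(
        raw_paths="data/images",
        raw_key="*.tif",
        label_paths="data/labels",
        label_key="*.tif",
        patch_shape=(512, 512),
        with_segmentation_decoder=True,
        batch_size=1,
        shuffle=True,
    )

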
CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources."""


def _find_best_configuration():
    if torch.cuda.is_available():

        # Check how much memory we have and select the best matching GPU
        # configuration for the available VRAM size.
        _, vram = torch.cuda.mem_get_info()
        vram = vram / 1e9  # In GB.

        # Maybe we can get more configurations in the future.
        if vram > 80:  # More than 80 GB: use the A100 configuration.
            return "A100"
        elif vram > 30:  # More than 30 GB: use the V100 configuration.
            return "V100"
        elif vram > 14:  # More than 14 GB: use the RTX5000 configuration.
            return "rtx5000"
        else:  # Otherwise: not enough memory to train on the GPU, use CPU instead.
            return "CPU"
    else:
        return "CPU"


582"""Best training configurations for given hardware resources.
583"""
584
585
def train_sam_for_configuration(
    name: str,
    configuration: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained.
            The checkpoint and logs will have this name.
        configuration: The configuration (= name of hardware resource).
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder
            for automatic instance segmentation.
        model_type: Override the default model type.
            This can be used to start from one of the micro_sam models
            instead of a default SAM model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
    """
    if configuration in CONFIGURATIONS:
        train_kwargs = CONFIGURATIONS[configuration]
    else:
        raise ValueError(f"Invalid configuration '{configuration}'. Expect one of {list(CONFIGURATIONS.keys())}.")

    if model_type is None:
        model_type = train_kwargs.pop("model_type")
    else:
        expected_model_type = train_kwargs.pop("model_type")
        if model_type[:5] != expected_model_type:
            warnings.warn(
                f"You have specified the model type '{model_type}', which differs from the default "
                f"'{expected_model_type}' for the '{configuration}' configuration."
            )

    train_kwargs.update(**kwargs)
    train_sam(
        name=name,
        train_loader=train_loader,
        val_loader=val_loader,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        model_type=model_type,
        **train_kwargs
    )


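# Editorial addition, not part of the original module: a minimal sketch that
# lets the hardware heuristic above pick the training settings. The loader
# arguments and the model name are hypothetical.
def _example_train_for_configuration(train_loader, val_loader):
    configuration = _find_best_configuration()  # E.g. "CPU", "rtx5000" or "V100".
    train_sam_for_configuration(
        name="sam_finetuned",  # Hypothetical checkpoint name.
        configuration=configuration,
        train_loader=train_loader,
        val_loader=val_loader,
    )

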
def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):

    # Whether the model is stored in the current working directory or in another location.
    if save_root is None:
        save_root = os.getcwd()  # Map this to the current working directory, if not specified by the user.

    # Get the 'best' model checkpoint ready for export.
    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
    if not os.path.exists(best_checkpoint):
        raise FileNotFoundError(f"The trained model was not found at the expected location: '{best_checkpoint}'.")

    # Export the model if an output path has been given.
    if output_path:

        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
            export_custom_sam_model(
                checkpoint_path=best_checkpoint,
                model_type=model_type[:5],
                save_path=output_path,
                with_segmentation_decoder=with_segmentation_decoder,
            )

        # Otherwise we export it as a bioimage.io model.
        else:
            from micro_sam.bioimageio import export_sam_model

            # Load an image and the corresponding labels from the val loader.
            with torch.no_grad():
                image_data, label_data = next(iter(val_loader))
                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()

                # Select the first channel of the label image if it has a channel axis,
                # i.e. the channel that contains the instance labels.
                if label_data.ndim == 3:
                    label_data = label_data[0]  # Gets the channel with the instances.
                assert image_data.shape == label_data.shape
                label_data = label_data.astype("uint32")

                export_sam_model(
                    image=image_data,
                    label_image=label_data,
                    model_type=model_type[:5],
                    name=checkpoint_name,
                    output_path=output_path,
                    checkpoint_path=best_checkpoint,
                )

        # The final path where the model has been stored.
        final_path = output_path

    else:  # If no export was requested, point the user to the best checkpoint.
        final_path = best_checkpoint

    return final_path


def main():
    """@private"""
    import argparse

    available_models = list(get_model_names())
    available_models_str = ", ".join(available_models)

    available_configurations = list(CONFIGURATIONS.keys())
    available_configurations_str = ", ".join(available_configurations)

    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")

    # Images and labels for training.
    parser.add_argument(
        "--images", required=True, type=str, nargs="*",
        help="Filepath to images or the directory where the image data is stored."
    )
    parser.add_argument(
        "--labels", required=True, type=str, nargs="*",
        help="Filepath to ground-truth labels or the directory where the label data is stored."
    )
    parser.add_argument(
        "--image_key", type=str, default=None,
        help="The key for accessing image data, either a glob pattern / wildcard or an internal path "
        "for elf.io.open_file."
    )
    parser.add_argument(
        "--label_key", type=str, default=None,
        help="The key for accessing label data, either a glob pattern / wildcard or an internal path "
        "for elf.io.open_file."
    )

    # Images and labels for validation.
    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
    # Users can choose to have their explicit validation set via this feature as well.
    parser.add_argument(
        "--val_images", type=str, nargs="*",
        help="Filepath to images for validation or the directory where the image data is stored."
    )
    parser.add_argument(
        "--val_labels", type=str, nargs="*",
        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
    )
    parser.add_argument(
        "--val_image_key", type=str, default=None,
        help="The key for accessing image data for validation, either a glob pattern / wildcard or an "
        "internal path for elf.io.open_file."
    )
    parser.add_argument(
        "--val_label_key", type=str, default=None,
        help="The key for accessing label data for validation, either a glob pattern / wildcard or an "
        "internal path for elf.io.open_file."
    )

    # Other necessary settings for training.
    parser.add_argument(
        "--configuration", type=str, default=_find_best_configuration(),
        help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations_str}."
    )

    def none_or_str(value):
        if value.lower() == 'none':
            return None
        return value

    parser.add_argument(
        "--segmentation_decoder", type=none_or_str, default="instances",
        # TODO: In the future, we can extend this to semantic segmentation or even more advanced tasks.
        help="Whether to finetune the Segment Anything Model with an additional segmentation decoder. "
        "By default, uses the 'instances' option, i.e. trains with the additional decoder for instance "
        "segmentation. Pass 'None' to train without the additional segmentation decoder."
    )

    # Optional advanced settings a user can opt to change the values for.
    parser.add_argument(
        "-d", "--device", type=str, default=None,
        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs="*", default=(512, 512),
        help="The choice of patch shape for training the Segment Anything Model. "
        "By default, a patch size of 512x512 is used."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=None,
        help=f"The Segment Anything Model that will be used for finetuning, one of {available_models_str}."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded for finetuning."
    )
    parser.add_argument(
        "-s", "--save_root", type=str, default=None,
        help="The directory where the trained models and corresponding logs will be stored. "
        "By default, they are stored in your current working directory."
    )
    parser.add_argument(
        "--trained_model_name", type=str, default="sam_model",
        help="The custom name of the trained model sub-folder. Allows users to have several trained models "
        "under the same 'save_root'."
    )
    parser.add_argument(
        "--output_path", type=str, default=None,
        help="The directory (e.g. '/path/to/folder') or filepath (e.g. '/path/to/model.pt') to export the trained model."
    )
    parser.add_argument(
        "--n_epochs", type=int, default=100,
        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
    )
    parser.add_argument(
        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1."
    )
    parser.add_argument(
        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
        help="Whether to normalize the raw inputs. By default, no preprocessing of the input images is performed. "
        "Otherwise, choose from 'normalize_percentile' or 'normalize_minmax'."
    )

    args = parser.parse_args()

    # 1. Get all the necessary parameters for training.
    checkpoint_name = args.trained_model_name
    config = args.configuration
    model_type = args.model_type
    checkpoint_path = args.checkpoint_path
    batch_size = args.batch_size
    patch_shape = args.patch_shape
    epochs = args.n_epochs
    num_workers = args.num_workers
    device = args.device
    save_root = args.save_root
    output_path = args.output_path

    if args.segmentation_decoder and args.segmentation_decoder != "instances":
        raise ValueError(
            "The 'segmentation_decoder' argument currently supports 'instances' as input argument only."
        )
    with_segmentation_decoder = (args.segmentation_decoder is not None)

    # Get image paths and corresponding keys.
    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa

    # 2. Prepare the dataloaders.

    # If the user wants to preprocess the inputs, we allow the possibility to do so.
    _raw_transform = get_raw_transform(args.preprocess)

    # Get the dataset with files for training.
    dataset = default_sam_dataset(
        raw_paths=train_images,
        raw_key=train_image_key,
        label_paths=train_gt,
        label_key=train_gt_key,
        patch_shape=patch_shape,
        with_segmentation_decoder=with_segmentation_decoder,
        raw_transform=_raw_transform,
    )

    # If no validation images are provided explicitly, we create a val-split from the training data.
    if val_images is None:
        assert val_gt is None and val_image_key is None and val_gt_key is None
        # Use 10% of the dataset - but at least one image - for validation.
        n_val = max(1, int(0.1 * len(dataset)))
        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])

    else:  # If validation images are provided, we create a separate dataset for them.
        train_dataset = dataset
        val_dataset = default_sam_dataset(
            raw_paths=val_images,
            raw_key=val_image_key,
            label_paths=val_gt,
            label_key=val_gt_key,
            patch_shape=patch_shape,
            with_segmentation_decoder=with_segmentation_decoder,
            raw_transform=_raw_transform,
        )

    # Get the dataloaders from the datasets.
    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    # 3. Train the Segment Anything Model.

    # Get a valid model and other necessary parameters for training.
    if model_type is not None and model_type not in available_models:
        raise ValueError(f"'{model_type}' is not a valid choice of model.")
    if config is not None and config not in available_configurations:
        raise ValueError(f"'{config}' is not a valid choice of configuration.")

    if model_type is None:  # If the user does not specify the model, we use the default model for the config.
        model_type = CONFIGURATIONS[config]["model_type"]

    train_sam_for_configuration(
        name=checkpoint_name,
        configuration=config,
        model_type=model_type,
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=epochs,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        freeze=None,  # TODO: Allow for PEFT.
        device=device,
        save_root=save_root,
        peft_kwargs=None,  # TODO: Allow for PEFT.
    )

    # 4. Export the model, if desired by the user.
    final_path = _export_helper(
        save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader
    )

    print(f"Training has finished. The trained model is saved at {final_path}.")
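

# Editorial addition, not part of the original module: a hypothetical
# invocation of the command line interface defined by `main` above. The name
# of the console script and all paths are assumptions, not confirmed here.
#
#   micro_sam.train --images data/images --image_key "*.tif" \
#       --labels data/labels --label_key "*.tif" \
#       --trained_model_name my_sam --output_path finetuned_model.pt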
FilePath = typing.Union[str, os.PathLike]
def train_sam( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, n_objects_per_batch: Optional[int] = 25, checkpoint_path: Union[str, os.PathLike, NoneType] = None, with_segmentation_decoder: bool = True, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, n_sub_iteration: int = 8, save_root: Union[str, os.PathLike, NoneType] = None, mask_prob: float = 0.5, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt5.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, box_distortion_factor: Optional[float] = 0.025, overwrite_training: bool = True, **model_kwargs) -> None:
169def train_sam(
170    name: str,
171    model_type: str,
172    train_loader: DataLoader,
173    val_loader: DataLoader,
174    n_epochs: int = 100,
175    early_stopping: Optional[int] = 10,
176    n_objects_per_batch: Optional[int] = 25,
177    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
178    with_segmentation_decoder: bool = True,
179    freeze: Optional[List[str]] = None,
180    device: Optional[Union[str, torch.device]] = None,
181    lr: float = 1e-5,
182    n_sub_iteration: int = 8,
183    save_root: Optional[Union[str, os.PathLike]] = None,
184    mask_prob: float = 0.5,
185    n_iterations: Optional[int] = None,
186    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
187    scheduler_kwargs: Optional[Dict[str, Any]] = None,
188    save_every_kth_epoch: Optional[int] = None,
189    pbar_signals: Optional[QObject] = None,
190    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
191    peft_kwargs: Optional[Dict] = None,
192    ignore_warnings: bool = True,
193    verify_n_labels_in_loader: Optional[int] = 50,
194    box_distortion_factor: Optional[float] = 0.025,
195    overwrite_training: bool = True,
196    **model_kwargs,
197) -> None:
198    """Run training for a SAM model.
199
200    Args:
201        name: The name of the model to be trained. The checkpoint and logs will have this name.
202        model_type: The type of the SAM model.
203        train_loader: The dataloader for training.
204        val_loader: The dataloader for validation.
205        n_epochs: The number of epochs to train for.
206        early_stopping: Enable early stopping after this number of epochs without improvement.
207        n_objects_per_batch: The number of objects per batch used to compute
208            the loss for interative segmentation. If None all objects will be used,
209            if given objects will be randomly sub-sampled.
210        checkpoint_path: Path to checkpoint for initializing the SAM model.
211        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
212        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
213            By default nothing is frozen and the full model is updated.
214        device: The device to use for training.
215        lr: The learning rate.
216        n_sub_iteration: The number of iterative prompts per training iteration.
217        save_root: Optional root directory for saving the checkpoints and logs.
218            If not given the current working directory is used.
219        mask_prob: The probability for using a mask as input in a given training sub-iteration.
220        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
221        scheduler_class: The learning rate scheduler to update the learning rate.
222            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
223        scheduler_kwargs: The learning rate scheduler parameters.
224            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
225        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
226        pbar_signals: Controls for napari progress bar.
227        optimizer_class: The optimizer class.
228            By default, torch.optim.AdamW is used.
229        peft_kwargs: Keyword arguments for the PEFT wrapper class.
230        ignore_warnings: Whether to ignore raised warnings.
231        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
232            By default, 50 batches of labels are verified from the dataloaders.
233        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
234        overwrite_training: Whether to overwrite the trained model stored at the same location.
235            By default, overwrites the trained model at each run.
236            If set to 'False', it will avoid retraining the model if the previous run was completed.
237        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
238    """
239    with _filter_warnings(ignore_warnings):
240
241        t_start = time.time()
242
243        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
244        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
245
246        device = get_device(device)
247        # Get the trainable segment anything model.
248        model, state = get_trainable_sam_model(
249            model_type=model_type,
250            device=device,
251            freeze=freeze,
252            checkpoint_path=checkpoint_path,
253            return_state=True,
254            peft_kwargs=peft_kwargs,
255            **model_kwargs
256        )
257
258        # This class creates all the training data for a batch (inputs, prompts and labels).
259        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)
260
261        # Create the UNETR decoder (if train with it) and the optimizer.
262        if with_segmentation_decoder:
263
264            # Get the UNETR.
265            unetr = get_unetr(
266                image_encoder=model.sam.image_encoder,
267                decoder_state=state.get("decoder_state", None),
268                device=device,
269            )
270
271            # Get the parameters for SAM and the decoder from UNETR.
272            joint_model_params = [params for params in model.parameters()]  # sam parameters
273            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
274                if not param_name.startswith("encoder"):
275                    joint_model_params.append(params)
276
277            model_params = joint_model_params
278        else:
279            model_params = model.parameters()
280
281        optimizer = optimizer_class(model_params, lr=lr)
282
283        if scheduler_kwargs is None:
284            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}
285
286        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)
287
288        # The trainer which performs training and validation.
289        if with_segmentation_decoder:
290            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
291            trainer = joint_trainers.JointSamTrainer(
292                name=name,
293                save_root=save_root,
294                train_loader=train_loader,
295                val_loader=val_loader,
296                model=model,
297                optimizer=optimizer,
298                device=device,
299                lr_scheduler=scheduler,
300                logger=joint_trainers.JointSamLogger,
301                log_image_interval=100,
302                mixed_precision=True,
303                convert_inputs=convert_inputs,
304                n_objects_per_batch=n_objects_per_batch,
305                n_sub_iteration=n_sub_iteration,
306                compile_model=False,
307                unetr=unetr,
308                instance_loss=instance_seg_loss,
309                instance_metric=instance_seg_loss,
310                early_stopping=early_stopping,
311                mask_prob=mask_prob,
312            )
313        else:
314            trainer = trainers.SamTrainer(
315                name=name,
316                train_loader=train_loader,
317                val_loader=val_loader,
318                model=model,
319                optimizer=optimizer,
320                device=device,
321                lr_scheduler=scheduler,
322                logger=trainers.SamLogger,
323                log_image_interval=100,
324                mixed_precision=True,
325                convert_inputs=convert_inputs,
326                n_objects_per_batch=n_objects_per_batch,
327                n_sub_iteration=n_sub_iteration,
328                compile_model=False,
329                early_stopping=early_stopping,
330                mask_prob=mask_prob,
331                save_root=save_root,
332            )
333
334        if n_iterations is None:
335            trainer_fit_params = {"epochs": n_epochs}
336        else:
337            trainer_fit_params = {"iterations": n_iterations}
338
339        if save_every_kth_epoch is not None:
340            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch
341
342        if pbar_signals is not None:
343            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
344            trainer_fit_params["progress"] = progress_bar_wrapper
345
346        # Avoid overwriting a trained model, if desired by the user.
347        trainer_fit_params["overwrite_training"] = overwrite_training
348
349        trainer.fit(**trainer_fit_params)
350
351        t_run = time.time() - t_start
352        hours = int(t_run // 3600)
353        minutes = int(t_run // 60)
354        seconds = int(round(t_run % 60, 0))
355        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")

Run training for a SAM model.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement.
  • n_objects_per_batch: The number of objects per batch used to compute the loss for interative segmentation. If None all objects will be used, if given objects will be randomly sub-sampled.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder By default nothing is frozen and the full model is updated.
  • device: The device to use for training.
  • lr: The learning rate.
  • n_sub_iteration: The number of iterative prompts per training iteration.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • mask_prob: The probability for using a mask as input in a given training sub-iteration.
  • n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings.
  • verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
  • box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
  • overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
  • model_kwargs: Additional keyword arguments for the util.get_sam_model.
def default_sam_dataset( raw_paths: Union[List[Union[os.PathLike, str]], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Union[os.PathLike, str]], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: Optional[bool] = None, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, is_train: bool = True, min_size: int = 25, max_sampling_attempts: Optional[int] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
391def default_sam_dataset(
392    raw_paths: Union[List[FilePath], FilePath],
393    raw_key: Optional[str],
394    label_paths: Union[List[FilePath], FilePath],
395    label_key: Optional[str],
396    patch_shape: Tuple[int],
397    with_segmentation_decoder: bool,
398    with_channels: Optional[bool] = None,
399    sampler: Optional[Callable] = None,
400    raw_transform: Optional[Callable] = None,
401    n_samples: Optional[int] = None,
402    is_train: bool = True,
403    min_size: int = 25,
404    max_sampling_attempts: Optional[int] = None,
405    **kwargs,
406) -> Dataset:
407    """Create a PyTorch Dataset for training a SAM model.
408
409    Args:
410        raw_paths: The path(s) to the image data used for training.
411            Can either be multiple 2D images or volumetric data.
412        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
413            or a glob pattern for selecting multiple files.
414        label_paths: The path(s) to the label data used for training.
415            Can either be multiple 2D images or volumetric data.
416        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
417            or a glob pattern for selecting multiple files.
418        patch_shape: The shape for training patches.
419        with_segmentation_decoder: Whether to train with additional segmentation decoder.
420        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
421        sampler: A sampler to reject batches according to a given criterion.
422        raw_transform: Transformation applied to the image data.
423            If not given the data will be cast to 8bit.
424        n_samples: The number of samples for this dataset.
425        is_train: Whether this dataset is used for training or validation.
426        min_size: Minimal object size. Smaller objects will be filtered.
427        max_sampling_attempts: Number of sampling attempts to make from a dataset.
428        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
429
430    Returns:
431        The segmentation dataset.
432    """
433
434    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
435    is_seg_dataset = kwargs.pop("is_seg_dataset", None)
436
437    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
438    # Get valid raw paths to make checks possible.
439    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
440        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
441    else:  # Otherwise, 'raw_key' is either None or an internal path in a container format supported by 'elf'; we load one filepath.
442        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
443
444    # Load one of the raw inputs to validate whether it is RGB or not.
445    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
446    if test_raw_inputs.ndim == 3:
447        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
448            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
449            # We need to provide a list of inputs to 'ImageCollectionDataset'.
450            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
451            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
452
453            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
454            with_channels = False if with_channels is None else with_channels
455
456        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is an RGB image and has channels first.
457            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
458            with_channels = True if with_channels is None else with_channels
459
460    # If 'with_channels' is still undecided, fall back to 'False', i.e. the default behavior
461    # of 'default_segmentation_dataset'. A value provided by the user always takes priority.
462    with_channels = False if with_channels is None else with_channels
463
464    # Set the data transformations.
465    if raw_transform is None:
466        raw_transform = require_8bit
467
468    if with_segmentation_decoder:
469        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
470            distances=True,
471            boundary_distances=True,
472            directed_distances=False,
473            foreground=True,
474            instances=True,
475            min_size=min_size,
476        )
477    else:
478        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
479
480    # Set a default sampler if none was passed.
481    if sampler is None:
482        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)
483
484    # Check the patch shape to add a singleton if required.
485    patch_shape = _update_patch_shape(
486        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
487    )
488
489    # Set a minimum number of samples per epoch.
490    if n_samples is None:
491        loader = torch_em.default_segmentation_loader(
492            raw_paths=raw_paths,
493            raw_key=raw_key,
494            label_paths=label_paths,
495            label_key=label_key,
496            batch_size=1,
497            patch_shape=patch_shape,
498            with_channels=with_channels,
499            ndim=2,
500            is_seg_dataset=is_seg_dataset,
501            raw_transform=raw_transform,
502            **kwargs
503        )
504        n_samples = max(len(loader), 100 if is_train else 5)
505
506    dataset = torch_em.default_segmentation_dataset(
507        raw_paths=raw_paths,
508        raw_key=raw_key,
509        label_paths=label_paths,
510        label_key=label_key,
511        patch_shape=patch_shape,
512        raw_transform=raw_transform,
513        label_transform=label_transform,
514        with_channels=with_channels,
515        ndim=2,
516        sampler=sampler,
517        n_samples=n_samples,
518        is_seg_dataset=is_seg_dataset,
519        **kwargs,
520    )
521
522    if max_sampling_attempts is not None:
523        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
524            for ds in dataset.datasets:
525                ds.max_sampling_attempts = max_sampling_attempts
526        else:
527            dataset.max_sampling_attempts = max_sampling_attempts
528
529    return dataset

Create a PyTorch Dataset for training a SAM model.

Arguments:
  • raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
  • raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
  • label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • patch_shape: The shape for training patches.
  • with_segmentation_decoder: Whether to train with additional segmentation decoder.
  • with_channels: Whether the image data has channels. If not given, this is derived from the input data.
  • sampler: A sampler to reject batches according to a given criterion.
  • raw_transform: Transformation applied to the image data. If not given, the data will be cast to 8-bit.
  • n_samples: The number of samples for this dataset.
  • is_train: Whether this dataset is used for training or validation.
  • min_size: Minimal object size. Smaller objects will be filtered.
  • max_sampling_attempts: Number of sampling attempts to make from a dataset.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.
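
For illustration, a small sketch of building such a dataset from a folder of 2D images with matching label images; the folder layout and glob keys are hypothetical:

    from micro_sam.training import default_sam_dataset

    dataset = default_sam_dataset(
        raw_paths="data/images",    # hypothetical folder of 2D tif images
        raw_key="*.tif",            # glob pattern selecting the image files
        label_paths="data/labels",  # matching instance label images
        label_key="*.tif",
        patch_shape=(512, 512),
        with_segmentation_decoder=True,  # selects the distance-based label transform
    )
    print(len(dataset))  # number of samples drawn per epoch (at least 100 for training)

As the source above shows, with_segmentation_decoder decides between the PerObjectDistanceTransform target and a plain minimum-size label transform.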

def default_sam_loader(**kwargs) -> torch.utils.data.dataloader.DataLoader:
532def default_sam_loader(**kwargs) -> DataLoader:
533    """Create a PyTorch DataLoader for training a SAM model.
534
535    Args:
536        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
537
538    Returns:
539        The DataLoader.
540    """
541    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)
542
543    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
544    # which the users can provide to get their desired segmentation dataset.
545    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
546    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}
547
548    ds = default_sam_dataset(**ds_kwargs)
549    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)

Create a PyTorch DataLoader for training a SAM model.

Arguments:
  • kwargs: Keyword arguments for micro_sam.training.default_sam_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.
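
Since all keyword arguments are passed together and split internally, dataset parameters and PyTorch DataLoader parameters can be mixed freely in a single call. A sketch with hypothetical paths:

    from micro_sam.training import default_sam_loader

    train_loader = default_sam_loader(
        # These are consumed by default_sam_dataset:
        raw_paths="data/images", raw_key="*.tif",
        label_paths="data/labels", label_key="*.tif",
        patch_shape=(512, 512), with_segmentation_decoder=True,
        # These are forwarded to the PyTorch DataLoader:
        batch_size=2, shuffle=True, num_workers=4,
    )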

CONFIGURATIONS = {
    'Minimal': {'model_type': 'vit_t', 'n_objects_per_batch': 4, 'n_sub_iteration': 4},
    'CPU': {'model_type': 'vit_b', 'n_objects_per_batch': 10},
    'gtx1080': {'model_type': 'vit_t', 'n_objects_per_batch': 5},
    'rtx5000': {'model_type': 'vit_b', 'n_objects_per_batch': 10},
    'V100': {'model_type': 'vit_b'},
    'A100': {'model_type': 'vit_h'},
}
def train_sam_for_configuration( name: str, configuration: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, checkpoint_path: Union[str, os.PathLike, NoneType] = None, with_segmentation_decoder: bool = True, model_type: Optional[str] = None, **kwargs) -> None:
587def train_sam_for_configuration(
588    name: str,
589    configuration: str,
590    train_loader: DataLoader,
591    val_loader: DataLoader,
592    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
593    with_segmentation_decoder: bool = True,
594    model_type: Optional[str] = None,
595    **kwargs,
596) -> None:
597    """Run training for a SAM model with the configuration for a given hardware resource.
598
599    Selects the best training settings for the given configuration.
600    The available configurations are listed in `CONFIGURATIONS`.
601
602    Args:
603        name: The name of the model to be trained.
604            The checkpoint and logs will have this name.
605        configuration: The configuration (= name of hardware resource).
606        train_loader: The dataloader for training.
607        val_loader: The dataloader for validation.
608        checkpoint_path: Path to checkpoint for initializing the SAM model.
609        with_segmentation_decoder: Whether to train additional UNETR decoder
610            for automatic instance segmentation.
611        model_type: Override the default model type.
612            This can be used to start from one of the micro_sam models
613            instead of a default SAM model.
614        kwargs: Additional keyword parameters that will be passed to `train_sam`.
615    """
616    if configuration in CONFIGURATIONS:
617        train_kwargs = CONFIGURATIONS[configuration]
618    else:
619        raise ValueError(f"Invalid configuration {configuration}, expect one of {list(CONFIGURATIONS.keys())}.")
620
621    if model_type is None:
622        model_type = train_kwargs.pop("model_type")
623    else:
624        expected_model_type = train_kwargs.pop("model_type")
625        if model_type[:5] != expected_model_type:
626            warnings.warn(f"You have specified a different model type ({model_type}) than the default for this configuration ({expected_model_type}).")
627
628    train_kwargs.update(**kwargs)
629    train_sam(
630        name=name,
631        train_loader=train_loader,
632        val_loader=val_loader,
633        checkpoint_path=checkpoint_path,
634        with_segmentation_decoder=with_segmentation_decoder,
635        model_type=model_type,
636        **train_kwargs
637    )

Run training for a SAM model with the configuration for a given hardware resource.

Selects the best training settings for the given configuration. The available configurations are listed in CONFIGURATIONS.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • configuration: The configuration (= name of hardware resource).
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
  • model_type: Override the default model type. This can be used to start from one of the micro_sam models instead of a default SAM model.
  • kwargs: Additional keyword parameters that will be passed to train_sam.
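
A sketch of training with one of the presets from CONFIGURATIONS; the loaders are placeholders (e.g. built with default_sam_loader) and the model name is illustrative:

    from micro_sam.training import train_sam_for_configuration

    train_sam_for_configuration(
        name="sam_rtx5000",         # checkpoint / log name
        configuration="rtx5000",    # preset: vit_b with 10 objects per batch
        train_loader=train_loader,  # placeholder dataloaders
        val_loader=val_loader,
        with_segmentation_decoder=True,
    )

The model_type check in the source only compares the backbone prefix, so a fine-tuned micro_sam variant with the same backbone (e.g. vit_b_lm for the vit_b presets) can be passed without triggering the warning.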