micro_sam.training.training

  1import os
  2import time
  3import warnings
  4from glob import glob
  5from tqdm import tqdm
  6from contextlib import contextmanager, nullcontext
  7from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  8
  9import imageio.v3 as imageio
 10
 11import torch
 12from torch.optim import Optimizer
 13from torch.utils.data import random_split
 14from torch.utils.data import DataLoader, Dataset
 15from torch.optim.lr_scheduler import _LRScheduler
 16
 17import torch_em
 18from torch_em.util import load_data
 19from torch_em.data.datasets.util import split_kwargs
 20
 21from elf.io import open_file
 22
 23try:
 24    from qtpy.QtCore import QObject
 25except Exception:
 26    QObject = Any
 27
 28from . import sam_trainer as trainers
 29from ..instance_segmentation import get_unetr
 30from . import joint_sam_trainer as joint_trainers
 31from ..util import get_device, get_model_names, export_custom_sam_model
 32from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform
 33
 34
 35FilePath = Union[str, os.PathLike]
 36
 37
 38def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
 39    x, _ = next(iter(loader))
 40
 41    # Raw data: check that we have 1 or 3 channels.
 42    n_channels = x.shape[1]
 43    if n_channels not in (1, 3):
 44        raise ValueError(
 45            "Invalid number of channels for the input data from the data loader. "
 46            f"Expect 1 or 3 channels, got {n_channels}."
 47        )
 48
 49    # Raw data: check that it is between [0, 255]
 50    minval, maxval = x.min(), x.max()
 51    if minval < 0 or minval > 255:
 52        raise ValueError(
 53            "Invalid data range for the input data from the data loader. "
 54            f"The input has to be in range [0, 255], but got minimum value {minval}."
 55        )
 56    if maxval < 1 or maxval > 255:
 57        raise ValueError(
 58            "Invalid data range for the input data from the data loader. "
 59            f"The input has to be in range [0, 255], but got maximum value {maxval}."
 60        )
 61
 62    # Target data: the check depends on whether we train with or without decoder.
 63    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).
 64
 65    def _check_instance_channel(instance_channel):
 66        unique_vals = torch.unique(instance_channel)
 67        if (unique_vals < 0).any():
 68            raise ValueError(
 69                "The target channel with the instance segmentation must not have negative values."
 70            )
 71        if len(unique_vals) == 1:
 72            raise ValueError(
 73                "The target channel with the instance segmentation must have at least one instance."
 74            )
 75        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
 76            raise ValueError(
 77                "All values in the target channel with the instance segmentation must be integer."
 78            )
 79
 80    counter = 0
 81    name = "" if name is None else f"'{name}'"
 82    for x, y in tqdm(
 83        loader,
 84        desc=f"Verifying labels in {name} dataloader",
 85        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
 86    ):
 87        n_channels_y = y.shape[1]
 88        if with_segmentation_decoder:
 89            if n_channels_y != 4:
 90                raise ValueError(
 91                    "Invalid number of channels in the target data from the data loader. "
 92                    "Expect 4 channels for training with an instance segmentation decoder, "
 93                    f"but got {n_channels_y} channels."
 94                )
 95            # Check instance channel per sample in a batch
 96            for per_y_sample in y:
 97                _check_instance_channel(per_y_sample[0])
 98
 99            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
100            if targets_min < 0 or targets_min > 1:
101                raise ValueError(
102                    "Invalid value range in the target data from the data loader. "
103                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
104                    f"to be in range [0.0, 1.0], but got min {targets_min}"
105                )
106            if targets_max < 0 or targets_max > 1:
107                raise ValueError(
108                    "Invalid value range in the target data from the data loader. "
109                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
110                    f"to be in range [0.0, 1.0], but got max {targets_max}"
111                )
112
113        else:
114            if n_channels_y != 1:
115                raise ValueError(
116                    "Invalid number of channels in the target data from the data loader. "
117                    "Expect 1 channel for training without an instance segmentation decoder, "
118                    f"but got {n_channels_y} channels."
119                )
120            # Check instance channel per sample in a batch
121            for per_y_sample in y:
122                _check_instance_channel(per_y_sample)
123
124        counter += 1
125        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
126            break
127
128
129# Make the progress bar callbacks compatible with a tqdm progress bar interface.
130class _ProgressBarWrapper:
131    def __init__(self, signals):
132        self._signals = signals
133        self._total = None
134
135    @property
136    def total(self):
137        return self._total
138
139    @total.setter
140    def total(self, value):
141        self._signals.pbar_total.emit(value)
142        self._total = value
143
144    def update(self, steps):
145        self._signals.pbar_update.emit(steps)
146
147    def set_description(self, desc, **kwargs):
148        self._signals.pbar_description.emit(desc)
149
150
151@contextmanager
152def _filter_warnings(ignore_warnings):
153    if ignore_warnings:
154        with warnings.catch_warnings():
155            warnings.simplefilter("ignore")
156            yield
157    else:
158        with nullcontext():
159            yield
160
161
162def _count_parameters(model_parameters):
163    params = sum(p.numel() for p in model_parameters if p.requires_grad)
164    params = params / 1e6
165    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")
166
167
168def train_sam(
169    name: str,
170    model_type: str,
171    train_loader: DataLoader,
172    val_loader: DataLoader,
173    n_epochs: int = 100,
174    early_stopping: Optional[int] = 10,
175    n_objects_per_batch: Optional[int] = 25,
176    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
177    with_segmentation_decoder: bool = True,
178    freeze: Optional[List[str]] = None,
179    device: Optional[Union[str, torch.device]] = None,
180    lr: float = 1e-5,
181    n_sub_iteration: int = 8,
182    save_root: Optional[Union[str, os.PathLike]] = None,
183    mask_prob: float = 0.5,
184    n_iterations: Optional[int] = None,
185    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
186    scheduler_kwargs: Optional[Dict[str, Any]] = None,
187    save_every_kth_epoch: Optional[int] = None,
188    pbar_signals: Optional[QObject] = None,
189    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
190    peft_kwargs: Optional[Dict] = None,
191    ignore_warnings: bool = True,
192    verify_n_labels_in_loader: Optional[int] = 50,
193    box_distortion_factor: Optional[float] = 0.025,
194    **model_kwargs,
195) -> None:
196    """Run training for a SAM model.
197
198    Args:
199        name: The name of the model to be trained. The checkpoint and logs will have this name.
200        model_type: The type of the SAM model.
201        train_loader: The dataloader for training.
202        val_loader: The dataloader for validation.
203        n_epochs: The number of epochs to train for.
204        early_stopping: Enable early stopping after this number of epochs without improvement.
205        n_objects_per_batch: The number of objects per batch used to compute
206            the loss for interactive segmentation. If None, all objects will be used;
207            if given, objects will be randomly sub-sampled.
208        checkpoint_path: Path to checkpoint for initializing the SAM model.
209        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
210        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
211            By default nothing is frozen and the full model is updated.
212        device: The device to use for training.
213        lr: The learning rate.
214        n_sub_iteration: The number of iterative prompts per training iteration.
215        save_root: Optional root directory for saving the checkpoints and logs.
216            If not given the current working directory is used.
217        mask_prob: The probability for using a mask as input in a given training sub-iteration.
218        n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
219        scheduler_class: The learning rate scheduler to update the learning rate.
220            By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
221        scheduler_kwargs: The learning rate scheduler parameters.
222            If passed None, the chosen default parameters are used in ReduceLROnPlateau.
223        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
224        pbar_signals: Controls for napari progress bar.
225        optimizer_class: The optimizer class.
226            By default, torch.optim.AdamW is used.
227        peft_kwargs: Keyword arguments for the PEFT wrapper class.
228        ignore_warnings: Whether to ignore raised warnings.
229        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
230            By default, 50 batches of labels are verified from the dataloaders.
231        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
232        model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
233    """
234    with _filter_warnings(ignore_warnings):
235
236        t_start = time.time()
237
238        _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
239        _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)
240
241        device = get_device(device)
242        # Get the trainable segment anything model.
243        model, state = get_trainable_sam_model(
244            model_type=model_type,
245            device=device,
246            freeze=freeze,
247            checkpoint_path=checkpoint_path,
248            return_state=True,
249            peft_kwargs=peft_kwargs,
250            **model_kwargs
251        )
252
253        # This class creates all the training data for a batch (inputs, prompts and labels).
254        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)
255
256        # Create the UNETR decoder (if train with it) and the optimizer.
257        if with_segmentation_decoder:
258
259            # Get the UNETR.
260            unetr = get_unetr(
261                image_encoder=model.sam.image_encoder,
262                decoder_state=state.get("decoder_state", None),
263                device=device,
264            )
265
266            # Get the parameters for SAM and the decoder from UNETR.
267            joint_model_params = [params for params in model.parameters()]  # sam parameters
268            for param_name, params in unetr.named_parameters():  # unetr's decoder parameters
269                if not param_name.startswith("encoder"):
270                    joint_model_params.append(params)
271
272            model_params = joint_model_params
273        else:
274            model_params = model.parameters()
275
276        optimizer = optimizer_class(model_params, lr=lr)
277
278        if scheduler_kwargs is None:
279            scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}
280
281        scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)
282
283        # The trainer which performs training and validation.
284        if with_segmentation_decoder:
285            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
286            trainer = joint_trainers.JointSamTrainer(
287                name=name,
288                save_root=save_root,
289                train_loader=train_loader,
290                val_loader=val_loader,
291                model=model,
292                optimizer=optimizer,
293                device=device,
294                lr_scheduler=scheduler,
295                logger=joint_trainers.JointSamLogger,
296                log_image_interval=100,
297                mixed_precision=True,
298                convert_inputs=convert_inputs,
299                n_objects_per_batch=n_objects_per_batch,
300                n_sub_iteration=n_sub_iteration,
301                compile_model=False,
302                unetr=unetr,
303                instance_loss=instance_seg_loss,
304                instance_metric=instance_seg_loss,
305                early_stopping=early_stopping,
306                mask_prob=mask_prob,
307            )
308        else:
309            trainer = trainers.SamTrainer(
310                name=name,
311                train_loader=train_loader,
312                val_loader=val_loader,
313                model=model,
314                optimizer=optimizer,
315                device=device,
316                lr_scheduler=scheduler,
317                logger=trainers.SamLogger,
318                log_image_interval=100,
319                mixed_precision=True,
320                convert_inputs=convert_inputs,
321                n_objects_per_batch=n_objects_per_batch,
322                n_sub_iteration=n_sub_iteration,
323                compile_model=False,
324                early_stopping=early_stopping,
325                mask_prob=mask_prob,
326                save_root=save_root,
327            )
328
329        if n_iterations is None:
330            trainer_fit_params = {"epochs": n_epochs}
331        else:
332            trainer_fit_params = {"iterations": n_iterations}
333
334        if save_every_kth_epoch is not None:
335            trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch
336
337        if pbar_signals is not None:
338            progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
339            trainer_fit_params["progress"] = progress_bar_wrapper
340
341        trainer.fit(**trainer_fit_params)
342
343        t_run = time.time() - t_start
344        hours = int(t_run // 3600)
345        minutes = int((t_run % 3600) // 60)
346        seconds = int(round(t_run % 60, 0))
347        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
348
349
350def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
351    if isinstance(raw_paths, (str, os.PathLike)):
352        path = raw_paths
353    else:
354        path = raw_paths[0]
355    assert isinstance(path, (str, os.PathLike))
356
357    # Check the underlying data dimensionality.
358    if raw_key is None:  # If no key is given then we assume it's an image file.
359        ndim = imageio.imread(path).ndim
360    else:  # Otherwise we try to open the file from key.
361        try:  # First try to open it with elf.
362            with open_file(path, "r") as f:
363                ndim = f[raw_key].ndim
364        except ValueError:  # This may fail for images in a folder with different sizes.
365            # In that case we read one of the images.
366            image_path = glob(os.path.join(path, raw_key))[0]
367            ndim = imageio.imread(image_path).ndim
368
369    if not isinstance(patch_shape, tuple):
370        patch_shape = tuple(patch_shape)
371
372    if ndim == 2:
373        assert len(patch_shape) == 2
374        return patch_shape
375    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
376        return (1,) + patch_shape
377    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
378        return (1,) + patch_shape
379    else:
380        return patch_shape
381
382
383def default_sam_dataset(
384    raw_paths: Union[List[FilePath], FilePath],
385    raw_key: Optional[str],
386    label_paths: Union[List[FilePath], FilePath],
387    label_key: Optional[str],
388    patch_shape: Tuple[int],
389    with_segmentation_decoder: bool,
390    with_channels: Optional[bool] = None,
391    sampler: Optional[Callable] = None,
392    raw_transform: Optional[Callable] = None,
393    n_samples: Optional[int] = None,
394    is_train: bool = True,
395    min_size: int = 25,
396    max_sampling_attempts: Optional[int] = None,
397    **kwargs,
398) -> Dataset:
399    """Create a PyTorch Dataset for training a SAM model.
400
401    Args:
402        raw_paths: The path(s) to the image data used for training.
403            Can either be multiple 2D images or volumetric data.
404        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
405            or a glob pattern for selecting multiple files.
406        label_paths: The path(s) to the label data used for training.
407            Can either be multiple 2D images or volumetric data.
408        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
409            or a glob pattern for selecting multiple files.
410        patch_shape: The shape for training patches.
411        with_segmentation_decoder: Whether to train with additional segmentation decoder.
412        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
413        sampler: A sampler to reject batches according to a given criterion.
414        raw_transform: Transformation applied to the image data.
415            If not given the data will be cast to 8bit.
416        n_samples: The number of samples for this dataset.
417        is_train: Whether this dataset is used for training or validation.
418        min_size: Minimal object size. Smaller objects will be filtered.
419        max_sampling_attempts: Number of sampling attempts to make from a dataset.
420        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
421
422    Returns:
423        The segmentation dataset.
424    """
425
426    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
427    is_seg_dataset = kwargs.pop("is_seg_dataset", None)
428
429    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
430    # Get valid raw paths to make checks possible.
431    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
432        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
433    else:  # Otherwise, 'raw_key' is either None or points into a container format supported by 'elf'; then we use one filepath.
434        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
435
436    # Load one of the raw inputs to validate whether it is RGB or not.
437    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
438    if test_raw_inputs.ndim == 3:
439        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
440            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
441            # We need to provide a list of inputs to 'ImageCollectionDataset'.
442            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
443            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
444
445            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
446            with_channels = False if with_channels is None else with_channels
447
448        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is a RGB image and has 3 channels first.
449            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
450            with_channels = True if with_channels is None else with_channels
451
452    # If 'with_channels' is still not set, fall back to 'False', i.e. the default behavior of 'default_segmentation_dataset'.
453    # Otherwise the user's choice (or the suggestion made above) takes priority.
454    with_channels = False if with_channels is None else with_channels
455
456    # Set the data transformations.
457    if raw_transform is None:
458        raw_transform = require_8bit
459
460    if with_segmentation_decoder:
461        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
462            distances=True,
463            boundary_distances=True,
464            directed_distances=False,
465            foreground=True,
466            instances=True,
467            min_size=min_size,
468        )
469    else:
470        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
471
472    # Set a default sampler if none was passed.
473    if sampler is None:
474        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)
475
476    # Check the patch shape to add a singleton if required.
477    patch_shape = _update_patch_shape(
478        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
479    )
480
481    # Set a minimum number of samples per epoch.
482    if n_samples is None:
483        loader = torch_em.default_segmentation_loader(
484            raw_paths=raw_paths,
485            raw_key=raw_key,
486            label_paths=label_paths,
487            label_key=label_key,
488            batch_size=1,
489            patch_shape=patch_shape,
490            with_channels=with_channels,
491            ndim=2,
492            is_seg_dataset=is_seg_dataset,
493            raw_transform=raw_transform,
494            **kwargs
495        )
496        n_samples = max(len(loader), 100 if is_train else 5)
497
498    dataset = torch_em.default_segmentation_dataset(
499        raw_paths=raw_paths,
500        raw_key=raw_key,
501        label_paths=label_paths,
502        label_key=label_key,
503        patch_shape=patch_shape,
504        raw_transform=raw_transform,
505        label_transform=label_transform,
506        with_channels=with_channels,
507        ndim=2,
508        sampler=sampler,
509        n_samples=n_samples,
510        is_seg_dataset=is_seg_dataset,
511        **kwargs,
512    )
513
514    if max_sampling_attempts is not None:
515        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
516            for ds in dataset.datasets:
517                ds.max_sampling_attempts = max_sampling_attempts
518        else:
519            dataset.max_sampling_attempts = max_sampling_attempts
520
521    return dataset
522
523
524def default_sam_loader(**kwargs) -> DataLoader:
525    """Create a PyTorch DataLoader for training a SAM model.
526
527    Args:
528        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
529
530    Returns:
531        The DataLoader.
532    """
533    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)
534
535    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
536    # which the users can provide to get their desired segmentation dataset.
537    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
538    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}
539
540    ds = default_sam_dataset(**ds_kwargs)
541    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
542
543
544CONFIGURATIONS = {
545    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
546    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
547    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
548    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
549    "V100": {"model_type": "vit_b"},
550    "A100": {"model_type": "vit_h"},
551}
552
553
554def _find_best_configuration():
555    if torch.cuda.is_available():
556
557        # Check how much memory we have and select the best matching GPU
558        # for the available VRAM size.
559        _, vram = torch.cuda.mem_get_info()
560        vram = vram / 1e9  # in GB
561
562        # Maybe we can get more configurations in the future.
563        if vram > 80:  # More than 80 GB: use the A100 configurations.
564            return "A100"
565        elif vram > 30:  # More than 30 GB: use the V100 configurations.
566            return "V100"
567        elif vram > 14:  # More than 14 GB: use the RTX5000 configurations.
568            return "rtx5000"
569        else:  # Otherwise: not enough memory to train on the GPU, use CPU instead.
570            return "CPU"
571    else:
572        return "CPU"
573
574
575"""Best training configurations for given hardware resources.
576"""
577
578
579def train_sam_for_configuration(
580    name: str,
581    configuration: str,
582    train_loader: DataLoader,
583    val_loader: DataLoader,
584    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
585    with_segmentation_decoder: bool = True,
586    model_type: Optional[str] = None,
587    **kwargs,
588) -> None:
589    """Run training for a SAM model with the configuration for a given hardware resource.
590
591    Selects the best training settings for the given configuration.
592    The available configurations are listed in `CONFIGURATIONS`.
593
594    Args:
595        name: The name of the model to be trained.
596            The checkpoint and logs will have this name.
597        configuration: The configuration (= name of hardware resource).
598        train_loader: The dataloader for training.
599        val_loader: The dataloader for validation.
600        checkpoint_path: Path to checkpoint for initializing the SAM model.
601        with_segmentation_decoder: Whether to train additional UNETR decoder
602            for automatic instance segmentation.
603        model_type: Over-ride the default model type.
604            This can be used to use one of the micro_sam models as starting point
605            instead of a default sam model.
606        kwargs: Additional keyword parameters that will be passed to `train_sam`.
607    """
608    if configuration in CONFIGURATIONS:
609        train_kwargs = CONFIGURATIONS[configuration]
610    else:
611        raise ValueError(f"Invalid configuration {configuration}, expect one of {list(CONFIGURATIONS.keys())}")
612
613    if model_type is None:
614        model_type = train_kwargs.pop("model_type")
615    else:
616        expected_model_type = train_kwargs.pop("model_type")
617        if model_type[:5] != expected_model_type:
618            warnings.warn("You have specified a different model type.")
619
620    train_kwargs.update(**kwargs)
621    train_sam(
622        name=name,
623        train_loader=train_loader,
624        val_loader=val_loader,
625        checkpoint_path=checkpoint_path,
626        with_segmentation_decoder=with_segmentation_decoder,
627        model_type=model_type,
628        **train_kwargs
629    )
630
631
632def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):
633
634    # Whether the model is stored in the current working directory or in another location.
635    if save_root is None:
636        save_root = os.getcwd()  # Map this to current working directory, if not specified by the user.
637
638    # Get the 'best' model checkpoint ready for export.
639    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
640    if not os.path.exists(best_checkpoint):
641        raise FileNotFoundError(f"The trained model was not found at the expected location: '{best_checkpoint}'.")
642
643    # Export the model if an output path has been given.
644    if output_path:
645
646        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
647        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
648            export_custom_sam_model(
649                checkpoint_path=best_checkpoint,
650                model_type=model_type[:5],
651                save_path=output_path,
652                with_segmentation_decoder=with_segmentation_decoder,
653            )
654
655        # Otherwise we export it as bioimage.io model.
656        else:
657            from micro_sam.bioimageio import export_sam_model
658
659            # Load image and corresponding labels from the val loader.
660            with torch.no_grad():
661                image_data, label_data = next(iter(val_loader))
662                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()
663
664                # Select the first channel of the label image if we have a channel axis, i.e. it contains the instance labels.
665                if label_data.ndim == 3:
666                    label_data = label_data[0]  # Gets the channel with instances.
667                assert image_data.shape == label_data.shape
668                label_data = label_data.astype("uint32")
669
670                export_sam_model(
671                    image=image_data,
672                    label_image=label_data,
673                    model_type=model_type[:5],
674                    name=checkpoint_name,
675                    output_path=output_path,
676                    checkpoint_path=best_checkpoint,
677                )
678
679        # The final path where the model has been stored.
680        final_path = output_path
681
682    else:  # If no exports have been made, inform the user about the best checkpoint.
683        final_path = best_checkpoint
684
685    return final_path
686
687
688def main():
689    """@private"""
690    import argparse
691
692    available_models = list(get_model_names())
693    available_models = ", ".join(available_models)
694
695    available_configurations = list(CONFIGURATIONS.keys())
696    available_configurations = ", ".join(available_configurations)
697
698    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")
699
700    # Images and labels for training.
701    parser.add_argument(
702        "--images", required=True, type=str, nargs="*",
703        help="Filepath to images or the directory where the image data is stored."
704    )
705    parser.add_argument(
706        "--labels", required=True, type=str, nargs="*",
707        help="Filepath to ground-truth labels or the directory where the label data is stored."
708    )
709    parser.add_argument(
710        "--image_key", type=str, default=None,
711        help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file. "
712    )
713    parser.add_argument(
714        "--label_key", type=str, default=None,
715        help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file. "
716    )
717
718    # Images and labels for validation.
719    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
720    # Users can choose to have their explicit validation set via this feature as well.
721    parser.add_argument(
722        "--val_images", type=str, nargs="*",
723        help="Filepath to images for validation or the directory where the image data is stored."
724    )
725    parser.add_argument(
726        "--val_labels", type=str, nargs="*",
727        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
728    )
729    parser.add_argument(
730        "--val_image_key", type=str, default=None,
731        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
732    )
733    parser.add_argument(
734        "--val_label_key", type=str, default=None,
735        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
736    )
737
738    # Other necessary stuff for training.
739    parser.add_argument(
740        "--configuration", type=str, default=_find_best_configuration(),
741        help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations}."
742    )
743    parser.add_argument(
744        "--segmentation_decoder", type=str, default="instances",  # TODO: in future, we can extend this to semantic seg.
745        help="Whether to finetune Segment Anything Model with additional segmentation decoder for desired targets. "
746        "By default, it trains with the additional segmentation decoder for instance segmentation."
747    )
748
749    # Optional advanced settings a user can opt to change the values for.
750    parser.add_argument(
751        "-d", "--device", type=str, default=None,
752        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
753        "By default the most performant available device will be selected."
754    )
755    parser.add_argument(
756        "--patch_shape", type=int, nargs="*", default=(512, 512),
757        help="The choice of patch shape for training Segment Anything."
758    )
759    parser.add_argument(
760        "-m", "--model_type", type=str, default=None,
761        help=f"The Segment Anything Model that will be used for finetuning, one of {available_models}."
762    )
763    parser.add_argument(
764        "--checkpoint_path", type=str, default=None,
765        help="Checkpoint from which the SAM model will be loaded for finetuning."
766    )
767    parser.add_argument(
768        "-s", "--save_root", type=str, default=None,
769        help="The directory where the trained models and corresponding logs will be stored. "
770        "By default, they are stored in your current working directory."
771    )
772    parser.add_argument(
773        "--trained_model_name", type=str, default="sam_model",
774        help="The custom name of the trained model. Allows users to have several trained models under the same 'save_root'."
775    )
776    parser.add_argument(
777        "--output_path", type=str, default=None,
778        help="The directory (e.g. '/path/to/folder') or filepath (e.g. '/path/to/model.pt') to export the trained model."
779    )
780    parser.add_argument(
781        "--n_epochs", type=int, default=100,
782        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
783    )
784    parser.add_argument(
785        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
786    )
787    parser.add_argument(
788        "--batch_size", type=int, default=1,
789        help="The choice of batch size for training the Segment Anything Model. By default, trains on batch size 1."
790    )
791    parser.add_argument(
792        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
793        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images. "
794        "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
795    )
796
797    args = parser.parse_args()
798
799    # 1. Get all necessary stuff for training.
800    checkpoint_name = args.trained_model_name
801    config = args.configuration
802    model_type = args.model_type
803    checkpoint_path = args.checkpoint_path
804    batch_size = args.batch_size
805    patch_shape = args.patch_shape
806    epochs = args.n_epochs
807    num_workers = args.num_workers
808    device = args.device
809    save_root = args.save_root
810    output_path = args.output_path
811    with_segmentation_decoder = (args.segmentation_decoder == "instances")
812
813    # Get image paths and corresponding keys.
814    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
815    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa
816
817    # 2. Prepare the dataloaders.
818
819    # If the user wants to preprocess the inputs, we allow the possibility to do so.
820    _raw_transform = get_raw_transform(args.preprocess)
821
822    # Get the dataset with files for training.
823    dataset = default_sam_dataset(
824        raw_paths=train_images,
825        raw_key=train_image_key,
826        label_paths=train_gt,
827        label_key=train_gt_key,
828        patch_shape=patch_shape,
829        with_segmentation_decoder=with_segmentation_decoder,
830        raw_transform=_raw_transform,
831    )
832
833    # If no validation images are explicitly provided, we create a val split from the training data.
834    if val_images is None:
835        assert val_gt is None and val_image_key is None and val_gt_key is None
836        # Use 10% of the dataset - at least one image - for validation.
837        n_val = max(1, int(0.1 * len(dataset)))
838        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])
839
840    else:  # If val images provided, we create a new dataset for it.
841        train_dataset = dataset
842        val_dataset = default_sam_dataset(
843            raw_paths=val_images,
844            raw_key=val_image_key,
845            label_paths=val_gt,
846            label_key=val_gt_key,
847            patch_shape=patch_shape,
848            with_segmentation_decoder=with_segmentation_decoder,
849            raw_transform=_raw_transform,
850        )
851
852    # Get the dataloaders from the datasets.
853    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
854    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)
855
856    # 3. Train the Segment Anything Model.
857
858    # Get a valid model and other necessary parameters for training.
859    if model_type is not None and model_type not in available_models:
860        raise ValueError(f"'{model_type}' is not a valid choice of model.")
861    if config is not None and config not in available_configurations:
862        raise ValueError(f"'{config}' is not a valid choice of configuration.")
863
864    if model_type is None:  # If user does not specify the model, we use the default model corresponding to the config.
865        model_type = CONFIGURATIONS[config]["model_type"]
866
867    train_sam_for_configuration(
868        name=checkpoint_name,
869        configuration=config,
870        model_type=model_type,
871        train_loader=train_loader,
872        val_loader=val_loader,
873        n_epochs=epochs,
874        checkpoint_path=checkpoint_path,
875        with_segmentation_decoder=with_segmentation_decoder,
876        freeze=None,  # TODO: Allow for PEFT.
877        device=device,
878        save_root=save_root,
879        peft_kwargs=None,  # TODO: Allow for PEFT.
880    )
881
882    # 4. Export the model, if desired by the user
883    final_path = _export_helper(
884        save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader
885    )
886
887    print(f"Training has finished. The trained model is saved at {final_path}.")
FilePath = typing.Union[str, os.PathLike]
def train_sam( name: str, model_type: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, n_epochs: int = 100, early_stopping: Optional[int] = 10, n_objects_per_batch: Optional[int] = 25, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, freeze: Optional[List[str]] = None, device: Union[str, torch.device, NoneType] = None, lr: float = 1e-05, n_sub_iteration: int = 8, save_root: Union[os.PathLike, str, NoneType] = None, mask_prob: float = 0.5, n_iterations: Optional[int] = None, scheduler_class: Optional[torch.optim.lr_scheduler._LRScheduler] = <class 'torch.optim.lr_scheduler.ReduceLROnPlateau'>, scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[PyQt5.QtCore.QObject] = None, optimizer_class: Optional[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, verify_n_labels_in_loader: Optional[int] = 50, box_distortion_factor: Optional[float] = 0.025, **model_kwargs) -> None:

Run training for a SAM model.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • model_type: The type of the SAM model.
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • n_epochs: The number of epochs to train for.
  • early_stopping: Enable early stopping after this number of epochs without improvement.
  • n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
  • freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
  • device: The device to use for training.
  • lr: The learning rate.
  • n_sub_iteration: The number of iterative prompts per training iteration.
  • save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
  • mask_prob: The probability for using a mask as input in a given training sub-iteration.
  • n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
  • scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
  • scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau.
  • save_every_kth_epoch: Save checkpoints after every kth epoch separately.
  • pbar_signals: Controls for napari progress bar.
  • optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • ignore_warnings: Whether to ignore raised warnings.
  • verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
  • box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
  • model_kwargs: Additional keyword arguments for the util.get_sam_model.
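
A minimal call sketch for direct use, assuming `train_loader` and `val_loader` were built beforehand (for example with `default_sam_loader`, via the `micro_sam.training` re-export mentioned above); the model name and hyperparameter values below are illustrative, not recommendations:

from micro_sam.training import train_sam

train_sam(
    name="sam_vit_b_finetuned",       # checkpoint and log name (illustrative)
    model_type="vit_b",
    train_loader=train_loader,        # e.g. created with default_sam_loader
    val_loader=val_loader,
    n_epochs=25,
    with_segmentation_decoder=True,   # also train the UNETR decoder for instance segmentation
    n_objects_per_batch=10,           # sub-sample objects per batch to limit GPU memory
)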
def default_sam_dataset( raw_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike], raw_key: Optional[str], label_paths: Union[List[Union[str, os.PathLike]], str, os.PathLike], label_key: Optional[str], patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: Optional[bool] = None, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, is_train: bool = True, min_size: int = 25, max_sampling_attempts: Optional[int] = None, **kwargs) -> torch.utils.data.dataset.Dataset:
384def default_sam_dataset(
385    raw_paths: Union[List[FilePath], FilePath],
386    raw_key: Optional[str],
387    label_paths: Union[List[FilePath], FilePath],
388    label_key: Optional[str],
389    patch_shape: Tuple[int],
390    with_segmentation_decoder: bool,
391    with_channels: Optional[bool] = None,
392    sampler: Optional[Callable] = None,
393    raw_transform: Optional[Callable] = None,
394    n_samples: Optional[int] = None,
395    is_train: bool = True,
396    min_size: int = 25,
397    max_sampling_attempts: Optional[int] = None,
398    **kwargs,
399) -> Dataset:
400    """Create a PyTorch Dataset for training a SAM model.
401
402    Args:
403        raw_paths: The path(s) to the image data used for training.
404            Can either be multiple 2D images or volumetric data.
405        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
406            or a glob pattern for selecting multiple files.
407        label_paths: The path(s) to the label data used for training.
408            Can either be multiple 2D images or volumetric data.
409        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
410            or a glob pattern for selecting multiple files.
411        patch_shape: The shape for training patches.
412        with_segmentation_decoder: Whether to train with additional segmentation decoder.
413        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
414        sampler: A sampler to reject batches according to a given criterion.
415        raw_transform: Transformation applied to the image data.
416            If not given the data will be cast to 8bit.
417        n_samples: The number of samples for this dataset.
418        is_train: Whether this dataset is used for training or validation.
419        min_size: Minimal object size. Smaller objects will be filtered.
420        max_sampling_attempts: Number of sampling attempts to make from a dataset.
421        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
422
423    Returns:
424        The segmentation dataset.
425    """
426
427    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
428    is_seg_dataset = kwargs.pop("is_seg_dataset", None)
429
430    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
431    # Get valid raw paths to make checks possible.
432    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
433        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
434    else:  # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath.
435        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]
436
437    # Load one of the raw inputs to validate whether it is RGB or not.
438    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
439    if test_raw_inputs.ndim == 3:
440        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
441            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
442            # We need to provide a list of inputs to 'ImageCollectionDataset'.
443            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
444            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths
445
446            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
447            with_channels = False if with_channels is None else with_channels
448
449        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is a RGB image and has 3 channels first.
450            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
451            with_channels = True if with_channels is None else with_channels
452
453    # If 'with_channels' is still not set, fall back to 'False', i.e. the default behavior
454    # of 'default_segmentation_dataset'. A value provided by the user always takes priority.
455    with_channels = False if with_channels is None else with_channels
456
457    # Set the data transformations.
458    if raw_transform is None:
459        raw_transform = require_8bit
460
461    if with_segmentation_decoder:
462        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
463            distances=True,
464            boundary_distances=True,
465            directed_distances=False,
466            foreground=True,
467            instances=True,
468            min_size=min_size,
469        )
470    else:
471        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)
472
473    # Set a default sampler if none was passed.
474    if sampler is None:
475        sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size)
476
477    # Check the patch shape to add a singleton if required.
478    patch_shape = _update_patch_shape(
479        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
480    )
481
482    # Set a minimum number of samples per epoch.
483    if n_samples is None:
484        loader = torch_em.default_segmentation_loader(
485            raw_paths=raw_paths,
486            raw_key=raw_key,
487            label_paths=label_paths,
488            label_key=label_key,
489            batch_size=1,
490            patch_shape=patch_shape,
491            with_channels=with_channels,
492            ndim=2,
493            is_seg_dataset=is_seg_dataset,
494            raw_transform=raw_transform,
495            **kwargs
496        )
497        n_samples = max(len(loader), 100 if is_train else 5)
498
499    dataset = torch_em.default_segmentation_dataset(
500        raw_paths=raw_paths,
501        raw_key=raw_key,
502        label_paths=label_paths,
503        label_key=label_key,
504        patch_shape=patch_shape,
505        raw_transform=raw_transform,
506        label_transform=label_transform,
507        with_channels=with_channels,
508        ndim=2,
509        sampler=sampler,
510        n_samples=n_samples,
511        is_seg_dataset=is_seg_dataset,
512        **kwargs,
513    )
514
515    if max_sampling_attempts is not None:
516        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
517            for ds in dataset.datasets:
518                ds.max_sampling_attempts = max_sampling_attempts
519        else:
520            dataset.max_sampling_attempts = max_sampling_attempts
521
522    return dataset

Create a PyTorch Dataset for training a SAM model.

Arguments:
  • raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data.
  • raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data.
  • label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files.
  • patch_shape: The shape for training patches.
  • with_segmentation_decoder: Whether to train with additional segmentation decoder.
  • with_channels: Whether the image data has channels. By default, this is derived from the input data.
  • sampler: A sampler to reject batches according to a given criterion.
  • raw_transform: Transformation applied to the image data. If not given, the data will be cast to 8-bit.
  • n_samples: The number of samples for this dataset.
  • is_train: Whether this dataset is used for training or validation.
  • min_size: Minimal object size. Smaller objects will be filtered.
  • max_sampling_attempts: Number of sampling attempts to make from a dataset.
  • kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:

The segmentation dataset.
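
A minimal usage sketch for this function (the folder layout and file names are hypothetical placeholders; any set of 2D images with matching instance label masks would work):

    import micro_sam.training as sam_training

    dataset = sam_training.default_sam_dataset(
        raw_paths="data/images",         # hypothetical folder with the training images
        raw_key="*.tif",                 # glob pattern selecting the image files
        label_paths="data/labels",       # hypothetical folder with the instance label masks
        label_key="*.tif",
        patch_shape=(512, 512),
        with_segmentation_decoder=True,  # also compute the distance-based targets for the UNETR decoder
    )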

def default_sam_loader(**kwargs) -> torch.utils.data.dataloader.DataLoader:
525def default_sam_loader(**kwargs) -> DataLoader:
526    """Create a PyTorch DataLoader for training a SAM model.
527
528    Args:
529        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
530
531    Returns:
532        The DataLoader.
533    """
534    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)
535
536    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
537    # which the users can provide to get their desired segmentation dataset.
538    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
539    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}
540
541    ds = default_sam_dataset(**ds_kwargs)
542    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)

Create a PyTorch DataLoader for training a SAM model.

Arguments:
  • kwargs: Keyword arguments for micro_sam.training.default_sam_dataset or for the PyTorch DataLoader.
Returns:

The DataLoader.
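
A minimal usage sketch, assuming the same hypothetical folder layout as above. Dataset arguments and DataLoader arguments (e.g. batch_size, shuffle, num_workers) can be combined in one call; they are separated internally via split_kwargs:

    from micro_sam.training import default_sam_loader

    train_loader = default_sam_loader(
        raw_paths="data/images", raw_key="*.tif",        # hypothetical image data
        label_paths="data/labels", label_key="*.tif",    # hypothetical label masks
        patch_shape=(512, 512),
        with_segmentation_decoder=True,
        batch_size=2, shuffle=True, num_workers=4,       # forwarded to the PyTorch DataLoader
    )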

CONFIGURATIONS = {
    'Minimal': {'model_type': 'vit_t', 'n_objects_per_batch': 4, 'n_sub_iteration': 4},
    'CPU': {'model_type': 'vit_b', 'n_objects_per_batch': 10},
    'gtx1080': {'model_type': 'vit_t', 'n_objects_per_batch': 5},
    'rtx5000': {'model_type': 'vit_b', 'n_objects_per_batch': 10},
    'V100': {'model_type': 'vit_b'},
    'A100': {'model_type': 'vit_h'},
}
def train_sam_for_configuration( name: str, configuration: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, checkpoint_path: Union[os.PathLike, str, NoneType] = None, with_segmentation_decoder: bool = True, model_type: Optional[str] = None, **kwargs) -> None:
580def train_sam_for_configuration(
581    name: str,
582    configuration: str,
583    train_loader: DataLoader,
584    val_loader: DataLoader,
585    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
586    with_segmentation_decoder: bool = True,
587    model_type: Optional[str] = None,
588    **kwargs,
589) -> None:
590    """Run training for a SAM model with the configuration for a given hardware resource.
591
592    Selects the best training settings for the given configuration.
593    The available configurations are listed in `CONFIGURATIONS`.
594
595    Args:
596        name: The name of the model to be trained.
597            The checkpoint and logs will have this name.
598        configuration: The configuration (= name of hardware resource).
599        train_loader: The dataloader for training.
600        val_loader: The dataloader for validation.
601        checkpoint_path: Path to checkpoint for initializing the SAM model.
602        with_segmentation_decoder: Whether to train additional UNETR decoder
603            for automatic instance segmentation.
604        model_type: Override the default model type.
605            This can be used to start from one of the micro_sam models
606            instead of a default SAM model.
607        kwargs: Additional keyword parameters that will be passed to `train_sam`.
608    """
609    if configuration in CONFIGURATIONS:
610        train_kwargs = CONFIGURATIONS[configuration]
611    else:
612        raise ValueError(f"Invalid configuration {configuration}, expect one of {list(CONFIGURATIONS.keys())}")
613
614    if model_type is None:
615        model_type = train_kwargs.pop("model_type")
616    else:
617        expected_model_type = train_kwargs.pop("model_type")
618        if model_type[:5] != expected_model_type:
619            warnings.warn("You have specified a different model type.")
620
621    train_kwargs.update(**kwargs)
622    train_sam(
623        name=name,
624        train_loader=train_loader,
625        val_loader=val_loader,
626        checkpoint_path=checkpoint_path,
627        with_segmentation_decoder=with_segmentation_decoder,
628        model_type=model_type,
629        **train_kwargs
630    )

Run training for a SAM model with the configuration for a given hardware resource.

Selects the best training settings for the given configuration. The available configurations are listed in CONFIGURATIONS.

Arguments:
  • name: The name of the model to be trained. The checkpoint and logs will have this name.
  • configuration: The configuration (= name of hardware resource).
  • train_loader: The dataloader for training.
  • val_loader: The dataloader for validation.
  • checkpoint_path: Path to checkpoint for initializing the SAM model.
  • with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
  • model_type: Override the default model type. This can be used to start from one of the micro_sam models instead of a default SAM model.
  • kwargs: Additional keyword parameters that will be passed to train_sam.
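
A minimal usage sketch, assuming train_loader and val_loader were created with default_sam_loader as above; the model name is a hypothetical placeholder and any extra keyword arguments would be forwarded to train_sam:

    from micro_sam.training import train_sam_for_configuration

    train_sam_for_configuration(
        name="sam_for_my_data",      # hypothetical name used for the checkpoint and logs
        configuration="V100",        # one of the keys in CONFIGURATIONS
        train_loader=train_loader,
        val_loader=val_loader,
        with_segmentation_decoder=True,
    )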