micro_sam.training.training
import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from collections import OrderedDict
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio
import numpy as np
import torch
from torch.optim import Optimizer
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler

import torch_em
from torch_em.util import load_data
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from ..models.peft_sam import ClassicalSurgery
from . import joint_sam_trainer as joint_trainers
from ..util import get_device, get_model_names, export_custom_sam_model, get_sam_model
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform


FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is in [0, 255].
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from the dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    counter = 0
    name = "" if name is None else f"'{name}'"
    for x, y in tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None,
    ):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channels for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}."
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}."
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check the instance channel per sample in a batch.
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        counter += 1
        if verify_n_labels_in_loader is not None and counter >= verify_n_labels_in_loader:
            break
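To make the layout that `_check_loader` expects concrete, here is a minimal, self-contained sketch with synthetic tensors. The shapes, values, and the toy dataset are hypothetical, chosen only to satisfy the checks for training with the segmentation decoder:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Raw data: (batch, channel, height, width) with 1 or 3 channels and values in [0, 255].
x = torch.randint(0, 256, (4, 1, 512, 512)).float()

# Targets for decoder training: 4 channels, where channel 0 holds integer instance ids
# and channels 1-3 (foreground and normalized distances) stay within [0.0, 1.0].
y = torch.zeros(4, 4, 512, 512)
y[:, 0, 100:200, 100:200] = 1     # one integer-valued instance per sample
y[:, 1:, 100:200, 100:200] = 0.5  # normalized channels within [0, 1]

loader = DataLoader(TensorDataset(x, y), batch_size=2)
_check_loader(loader, with_segmentation_decoder=True, name="demo")  # passes silently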
# Make the progress bar callbacks compatible with a tqdm progress bar interface.
class _ProgressBarWrapper:
    def __init__(self, signals):
        self._signals = signals
        self._total = None

    @property
    def total(self):
        return self._total

    @total.setter
    def total(self, value):
        self._signals.pbar_total.emit(value)
        self._total = value

    def update(self, steps):
        self._signals.pbar_update.emit(steps)

    def set_description(self, desc, **kwargs):
        self._signals.pbar_description.emit(desc)


@contextmanager
def _filter_warnings(ignore_warnings):
    if ignore_warnings:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            yield
    else:
        with nullcontext():
            yield


def _count_parameters(model_parameters):
    params = sum(p.numel() for p in model_parameters if p.requires_grad)
    params = params / 1e6
    print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)")


def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training):
    if n_iterations is None:
        trainer_fit_params = {"epochs": n_epochs}
    else:
        trainer_fit_params = {"iterations": n_iterations}

    if save_every_kth_epoch is not None:
        trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch

    if pbar_signals is not None:
        progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
        trainer_fit_params["progress"] = progress_bar_wrapper

    # Avoid overwriting a trained model, if desired by the user.
    trainer_fit_params["overwrite_training"] = overwrite_training
    return trainer_fit_params


def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs):
    optimizer = optimizer_class(model_params, lr=lr)
    if scheduler_kwargs is None:
        scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3}
    scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs)
    return optimizer, scheduler
184 trainer_fit_params["overwrite_training"] = overwrite_training 185 return trainer_fit_params 186 187 188def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs): 189 optimizer = optimizer_class(model_params, lr=lr) 190 if scheduler_kwargs is None: 191 scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3} 192 scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs) 193 return optimizer, scheduler 194 195 196def train_sam( 197 name: str, 198 model_type: str, 199 train_loader: DataLoader, 200 val_loader: DataLoader, 201 n_epochs: int = 100, 202 early_stopping: Optional[int] = 10, 203 n_objects_per_batch: Optional[int] = 25, 204 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 205 with_segmentation_decoder: bool = True, 206 freeze: Optional[List[str]] = None, 207 device: Optional[Union[str, torch.device]] = None, 208 lr: float = 1e-5, 209 n_sub_iteration: int = 8, 210 save_root: Optional[Union[str, os.PathLike]] = None, 211 mask_prob: float = 0.5, 212 n_iterations: Optional[int] = None, 213 scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau, 214 scheduler_kwargs: Optional[Dict[str, Any]] = None, 215 save_every_kth_epoch: Optional[int] = None, 216 pbar_signals: Optional[QObject] = None, 217 optimizer_class: Optional[Optimizer] = torch.optim.AdamW, 218 peft_kwargs: Optional[Dict] = None, 219 ignore_warnings: bool = True, 220 verify_n_labels_in_loader: Optional[int] = 50, 221 box_distortion_factor: Optional[float] = 0.025, 222 overwrite_training: bool = True, 223 **model_kwargs, 224) -> None: 225 """Run training for a SAM model. 226 227 Args: 228 name: The name of the model to be trained. The checkpoint and logs will have this name. 229 model_type: The type of the SAM model. 230 train_loader: The dataloader for training. 231 val_loader: The dataloader for validation. 232 n_epochs: The number of epochs to train for. 233 early_stopping: Enable early stopping after this number of epochs without improvement. 234 By default, the value is set to '10' epochs. 235 n_objects_per_batch: The number of objects per batch used to compute 236 the loss for interative segmentation. If None all objects will be used, 237 if given objects will be randomly sub-sampled. By default, the number of objects per batch are '25'. 238 checkpoint_path: Path to checkpoint for initializing the SAM model. 239 with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. 240 By default, trains with the additional instance segmentation decoder. 241 freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder 242 By default nothing is frozen and the full model is updated. 243 device: The device to use for training. By default, automatically chooses the best available device to train. 244 lr: The learning rate. By default, set to '1e-5'. 245 n_sub_iteration: The number of iterative prompts per training iteration. 246 By default, the number of iterations is set to '8'. 247 save_root: Optional root directory for saving the checkpoints and logs. 248 If not given the current working directory is used. 249 mask_prob: The probability for using a mask as input in a given training sub-iteration. 250 By default, set to '0.5'. 251 n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given. 252 scheduler_class: The learning rate scheduler to update the learning rate. 
253 By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used. 254 scheduler_kwargs: The learning rate scheduler parameters. 255 If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`. 256 save_every_kth_epoch: Save checkpoints after every kth epoch separately. 257 pbar_signals: Controls for napari progress bar. 258 optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used. 259 peft_kwargs: Keyword arguments for the PEFT wrapper class. 260 ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'. 261 verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. 262 By default, 50 batches of labels are verified from the dataloaders. 263 box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. 264 By default, the distortion factor is set to '0.025'. 265 overwrite_training: Whether to overwrite the trained model stored at the same location. 266 By default, overwrites the trained model at each run. 267 If set to 'False', it will avoid retraining the model if the previous run was completed. 268 model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`. 269 """ 270 with _filter_warnings(ignore_warnings): 271 272 t_start = time.time() 273 if verify_n_labels_in_loader is not None: 274 _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader) 275 _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader) 276 277 device = get_device(device) 278 # Get the trainable segment anything model. 279 model, state = get_trainable_sam_model( 280 model_type=model_type, 281 device=device, 282 freeze=freeze, 283 checkpoint_path=checkpoint_path, 284 return_state=True, 285 peft_kwargs=peft_kwargs, 286 **model_kwargs 287 ) 288 289 # This class creates all the training data for a batch (inputs, prompts and labels). 290 convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor) 291 292 # Create the UNETR decoder (if train with it) and the optimizer. 293 if with_segmentation_decoder: 294 295 # Get the UNETR. 296 unetr = get_unetr( 297 image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device, 298 ) 299 300 # Get the parameters for SAM and the decoder from UNETR. 301 joint_model_params = [params for params in model.parameters()] # sam parameters 302 for param_name, params in unetr.named_parameters(): # unetr's decoder parameters 303 if not param_name.startswith("encoder"): 304 joint_model_params.append(params) 305 306 model_params = joint_model_params 307 else: 308 model_params = model.parameters() 309 310 optimizer, scheduler = _get_optimizer_and_scheduler( 311 model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs 312 ) 313 314 # The trainer which performs training and validation. 
315 if with_segmentation_decoder: 316 instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True) 317 trainer = joint_trainers.JointSamTrainer( 318 name=name, 319 save_root=save_root, 320 train_loader=train_loader, 321 val_loader=val_loader, 322 model=model, 323 optimizer=optimizer, 324 device=device, 325 lr_scheduler=scheduler, 326 logger=joint_trainers.JointSamLogger, 327 log_image_interval=100, 328 mixed_precision=True, 329 convert_inputs=convert_inputs, 330 n_objects_per_batch=n_objects_per_batch, 331 n_sub_iteration=n_sub_iteration, 332 compile_model=False, 333 unetr=unetr, 334 instance_loss=instance_seg_loss, 335 instance_metric=instance_seg_loss, 336 early_stopping=early_stopping, 337 mask_prob=mask_prob, 338 ) 339 else: 340 trainer = trainers.SamTrainer( 341 name=name, 342 train_loader=train_loader, 343 val_loader=val_loader, 344 model=model, 345 optimizer=optimizer, 346 device=device, 347 lr_scheduler=scheduler, 348 logger=trainers.SamLogger, 349 log_image_interval=100, 350 mixed_precision=True, 351 convert_inputs=convert_inputs, 352 n_objects_per_batch=n_objects_per_batch, 353 n_sub_iteration=n_sub_iteration, 354 compile_model=False, 355 early_stopping=early_stopping, 356 mask_prob=mask_prob, 357 save_root=save_root, 358 ) 359 360 trainer_fit_params = _get_trainer_fit_params( 361 n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training 362 ) 363 trainer.fit(**trainer_fit_params) 364 365 t_run = time.time() - t_start 366 hours = int(t_run // 3600) 367 minutes = int(t_run // 60) 368 seconds = int(round(t_run % 60, 0)) 369 print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)") 370 371 372def export_instance_segmentation_model( 373 trained_model_path: Union[str, os.PathLike], 374 output_path: Union[str, os.PathLike], 375 model_type: str, 376 initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None, 377) -> None: 378 """Export a model trained for instance segmentation with `train_instance_segmentation`. 379 380 The exported model will be compatible with the micro_sam functions, CLI and napari plugin. 381 It should only be used for automatic segmentation and may not work well for interactive segmentation. 382 383 Args: 384 trained_model_path: The path to the checkpoint of the model trained for instance segmentation. 385 output_path: The path where the exported model will be saved. 386 model_type: The model type. 387 initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional). 388 """ 389 trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"] 390 391 # Get the state of the encoder and instance segmentation decoder from the trained checkpoint. 392 encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")]) 393 decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")]) 394 395 # Load the original state of the model that was used as the basis of instance segmentation training. 396 _, model_state = get_sam_model( 397 model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu", 398 ) 399 # Remove the sam prefix if it's in the model state. 400 prefix = "sam." 401 model_state = OrderedDict( 402 [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] 403 ) 404 405 # Replace the image encoder state. 
def export_instance_segmentation_model(
    trained_model_path: Union[str, os.PathLike],
    output_path: Union[str, os.PathLike],
    model_type: str,
    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Export a model trained for instance segmentation with `train_instance_segmentation`.

    The exported model will be compatible with the micro_sam functions, CLI and napari plugin.
    It should only be used for automatic segmentation and may not work well for interactive segmentation.

    Args:
        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
        output_path: The path where the exported model will be saved.
        model_type: The model type.
        initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
    """
    trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"]

    # Get the state of the encoder and instance segmentation decoder from the trained checkpoint.
    encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")])
    decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")])

    # Load the original state of the model that was used as the basis of instance segmentation training.
    _, model_state = get_sam_model(
        model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu",
    )
    # Remove the sam prefix if it's in the model state.
    prefix = "sam."
    model_state = OrderedDict(
        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
    )

    # Replace the image encoder state. The trained checkpoint stores the encoder under "encoder.*",
    # while the SAM state uses "image_encoder.*"; dropping the first 6 characters ("image_") maps between the two.
    model_state = OrderedDict(
        [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v)
         for k, v in model_state.items()]
    )

    save_state = {"model_state": model_state, "decoder_state": decoder_state}
    torch.save(save_state, output_path)
def train_instance_segmentation(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
    metric: Optional[torch.nn.Module] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    save_root: Optional[Union[str, os.PathLike]] = None,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    overwrite_training: bool = True,
    **model_kwargs,
) -> None:
    """Train a UNETR for instance segmentation using the SAM encoder as backbone.

    This setting corresponds to training a SAM model with an instance segmentation decoder,
    without training the model parts for interactive segmentation,
    i.e. without training the prompt encoder and mask decoder.

    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
    will not be compatible with the micro_sam functionality.
    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
    in a format that is compatible with the micro_sam functionality.
    Note that the exported model should only be used for automatic segmentation via AIS.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
            By default, the value is set to '10' epochs.
        loss: The loss function to train the instance segmentation model.
            By default, the value is set to `torch_em.loss.DiceBasedDistanceLoss`.
        metric: The metric for the instance segmentation training.
            By default the loss function is used as the metric.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training. By default, automatically chooses the best available device to train.
        lr: The learning rate. By default, set to '1e-5'.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        n_iterations: The number of iterations to use for training. This will override `n_epochs` if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
        overwrite_training: Whether to overwrite the trained model stored at the same location.
            By default, overwrites the trained model at each run.
            If set to 'False', it will avoid retraining the model if the previous run was completed.
        model_kwargs: Additional keyword arguments for `micro_sam.util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):
        t_start = time.time()

        sam_model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            freeze=freeze,
            **model_kwargs
        )
        device = get_device(device)
        model = get_unetr(
            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
        )

        optimizer, scheduler = _get_optimizer_and_scheduler(
            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
        )
        trainer = torch_em.trainer.DefaultTrainer(
            name=name,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            mixed_precision=True,
            log_image_interval=50,
            compile_model=False,
            save_root=save_root,
            loss=loss,
            metric=loss if metric is None else metric,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            early_stopping=early_stopping,
        )

        trainer_fit_params = _get_trainer_fit_params(
            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
        )
        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        hours = int(t_run // 3600)
        minutes = int((t_run % 3600) // 60)
        seconds = int(round(t_run % 60, 0))
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
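For orientation, a hedged usage sketch for decoder-only training. All names and loaders are placeholders; the loaders are assumed to be created with `default_sam_loader` using `train_instance_segmentation_only=True` (see below):

train_instance_segmentation(
    name="ais_decoder_model",   # the checkpoint is saved to 'checkpoints/ais_decoder_model'
    model_type="vit_b",
    train_loader=train_loader,  # assumed: built with train_instance_segmentation_only=True
    val_loader=val_loader,
    n_epochs=50,
)
# The resulting 'best.pt' checkpoint is not directly usable with micro_sam;
# export it first with `export_instance_segmentation_model` (defined above).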
def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels):
    if isinstance(raw_paths, (str, os.PathLike)):
        path = raw_paths
    else:
        path = raw_paths[0]
    assert isinstance(path, (str, os.PathLike, np.ndarray, torch.Tensor))

    # Check the underlying data dimensionality.
    if raw_key is None and isinstance(path, (np.ndarray, torch.Tensor)):
        ndim = path.ndim
    elif raw_key is None:  # If no key is given and this is not a tensor/array then we assume it's an image file.
        ndim = imageio.imread(path).ndim
    else:  # Otherwise we try to open the file with the key.
        try:  # First try to open it with elf.
            with open_file(path, "r") as f:
                ndim = f[raw_key].ndim
        except ValueError:  # This may fail for images in a folder with different sizes.
            # In that case we read one of the images.
            image_path = glob(os.path.join(path, raw_key))[0]
            ndim = imageio.imread(image_path).ndim

    if not isinstance(patch_shape, tuple):
        patch_shape = tuple(patch_shape)

    if ndim == 2:
        assert len(patch_shape) == 2
        return patch_shape
    elif ndim == 3 and len(patch_shape) == 2 and not with_channels:
        return (1,) + patch_shape
    elif ndim == 4 and len(patch_shape) == 2 and with_channels:
        return (1,) + patch_shape
    else:
        return patch_shape
def _determine_dataset_and_channels(raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs):
    # By default, let the 'default_segmentation_dataset' heuristic decide for itself.
    is_seg_dataset = kwargs.pop("is_seg_dataset", None)

    # Check if the input data is numpy arrays or torch tensors. In this case we use
    # the image collection dataset heuristic (torch_em will figure it out),
    # and we determine the number of channels.
    if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)):
        is_seg_dataset = False
        if with_channels is None:
            with_channels = raw_paths[0].ndim == 3 and raw_paths[0].shape[-1] == 3
        return is_seg_dataset, with_channels

    # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'.
    # Get valid raw paths to make checks possible.
    if raw_key and "*" in raw_key:  # Use the wildcard pattern to find the filepath to only one image.
        rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0]
    else:  # Otherwise, either 'raw_key' is None or a container format supported by 'elf'; then we load one filepath.
        rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0]

    # Load one of the raw inputs to validate whether it is RGB or not.
    test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None)
    if test_raw_inputs.ndim == 3:
        if test_raw_inputs.shape[-1] == 3:  # i.e. if it is an RGB image and has channels last.
            is_seg_dataset = False  # we use 'ImageCollectionDataset' in this case.
            # We need to provide a list of inputs to 'ImageCollectionDataset'.
            raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths
            label_paths = [label_paths] if isinstance(label_paths, str) else label_paths

            # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'.
            with_channels = False if with_channels is None else with_channels

        elif test_raw_inputs.shape[0] == 3:  # i.e. if it is an RGB image and has channels first.
            # This is relevant for 'SegmentationDataset'. If not provided by the user, we set this to 'True'.
            with_channels = True if with_channels is None else with_channels

    # Otherwise set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset'.
    # The user's explicit choice takes priority; else we use this suggested default.
    with_channels = False if with_channels is None else with_channels

    return is_seg_dataset, with_channels
def default_sam_dataset(
    raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    train_instance_segmentation_only: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    is_multi_tensor: bool = True,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
            The data can also be passed as a list of numpy arrays or torch tensors.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
            The data can also be passed as a list of numpy arrays or torch tensors.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with an additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
        train_instance_segmentation_only: Set this argument to True in order to
            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
        sampler: A sampler to reject batches according to a given criterion.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset.
        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        rois: The region of interest(s) for the data.
        is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """

    # Check if this dataset should be used for instance-segmentation-only training.
    # If yes, we set return_instances to False, since the instance channel must not
    # be passed for this training mode.
    return_instances = True
    if train_instance_segmentation_only:
        if not with_segmentation_decoder:
            raise ValueError(
                "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
            )
        return_instances = False

    # If a sampler is not passed, then we set a MinInstanceSampler, which requires at least two distinct
    # label values per sample, i.e. background plus at least one instance.
    # This is necessary, because training for interactive segmentation does not work on 'empty' images.
    # However, if we train only the automatic instance segmentation decoder, then this sampler is not required
    # and we do not set a default sampler.
    if sampler is None and not train_instance_segmentation_only:
        sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)

    is_seg_dataset, with_channels = _determine_dataset_and_channels(
        raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs
    )

    # Set the data transformations.
    if raw_transform is None:
        raw_transform = require_8bit

    # Prepare the label transform.
    if with_segmentation_decoder:
        default_label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=return_instances,
            min_size=min_size,
        )
    else:
        default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # Allow combining label transforms.
    custom_label_transform = kwargs.pop("label_transform", None)
    if custom_label_transform is None:
        label_transform = default_label_transform
    else:
        label_transform = torch_em.transform.generic.Compose(
            custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor
        )

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Set a minimum number of samples per epoch.
    if n_samples is None:
        loader = torch_em.default_segmentation_loader(
            raw_paths=raw_paths,
            raw_key=raw_key,
            label_paths=label_paths,
            label_key=label_key,
            batch_size=1,
            patch_shape=patch_shape,
            with_channels=with_channels,
            ndim=2,
            is_seg_dataset=is_seg_dataset,
            raw_transform=raw_transform,
            rois=rois,
            **kwargs
        )
        n_samples = max(len(loader), 100 if is_train else 5)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        raw_transform=raw_transform,
        label_transform=label_transform,
        with_channels=with_channels,
        ndim=2,
        sampler=sampler,
        n_samples=n_samples,
        is_seg_dataset=is_seg_dataset,
        rois=rois,
        **kwargs,
    )

    if max_sampling_attempts is not None:
        if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset):
            for ds in dataset.datasets:
                ds.max_sampling_attempts = max_sampling_attempts
        else:
            dataset.max_sampling_attempts = max_sampling_attempts

    return dataset
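As a sketch of the in-memory input mode that the channel heuristic above anticipates (the arrays are random placeholders, and this assumes your torch_em version accepts array inputs, as the docstring states):

import numpy as np

images = [np.random.randint(0, 256, (512, 512), dtype="uint8") for _ in range(5)]
labels = [np.zeros((512, 512), dtype="uint32") for _ in range(5)]
for lab in labels:
    lab[100:200, 100:200] = 1  # ensure one instance larger than the default min_size of 25

ds = default_sam_dataset(
    raw_paths=images, raw_key=None,
    label_paths=labels, label_key=None,
    patch_shape=(512, 512),
    with_segmentation_decoder=True,
)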
def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # There might be additional parameters supported by `torch_em.default_segmentation_dataset`,
    # which the users can provide to get their desired segmentation dataset.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs)
    ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs}

    ds = default_sam_dataset(**ds_kwargs)
    return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
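A typical loader setup from folders of tif images might look like the following; the paths and glob patterns are placeholders, and the extra `batch_size`, `shuffle` and `num_workers` arguments are split off and forwarded to the PyTorch DataLoader:

train_loader = default_sam_loader(
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=2, shuffle=True, num_workers=4,
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    is_train=False, batch_size=1, shuffle=True,
)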
CONFIGURATIONS = {
    "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4},
    "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5},
    "gtx3080": {
        "model_type": "vit_b", "n_objects_per_batch": 5,
        "peft_kwargs": {"attention_layers_to_update": [11], "peft_module": ClassicalSurgery}
    },
    "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10},
    "V100": {"model_type": "vit_b"},
    "A100": {"model_type": "vit_h"},
}
"""Best training configurations for given hardware resources.
"""


def _find_best_configuration():
    if torch.cuda.is_available():

        # Check how much memory we have and select the best matching GPU
        # configuration for the available VRAM size.
        _, vram = torch.cuda.mem_get_info()
        vram = vram / 1e9  # in GB

        # Maybe we can get more configurations in the future.
        if vram > 80:  # More than 80 GB: use the A100 configuration.
            return "A100"
        elif vram > 30:  # More than 30 GB: use the V100 configuration.
            return "V100"
        elif vram > 14:  # More than 14 GB: use the rtx5000 configuration.
            return "rtx5000"
        elif vram > 8:  # More than 8 GB: use the gtx3080 configuration.
            return "gtx3080"
        else:  # Otherwise: not enough memory to train on the GPU, use the CPU instead.
            return "CPU"
    else:
        return "CPU"


def train_sam_for_configuration(
    name: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    configuration: Optional[str] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    train_instance_segmentation_only: bool = False,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        configuration: The configuration (= name of hardware resource).
            By default, the best configuration for the available VRAM is selected automatically.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation.
            By default, trains with the additional instance segmentation decoder.
        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
        model_type: Override the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default SAM model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`.
    """
    if configuration is None:  # Automatically choose the configuration based on the available VRAM.
        configuration = _find_best_configuration()

    if configuration in CONFIGURATIONS:
        train_kwargs = CONFIGURATIONS[configuration]
    else:
        raise ValueError(f"Invalid configuration {configuration}, expect one of {list(CONFIGURATIONS.keys())}")

    if model_type is None:
        model_type = train_kwargs.pop("model_type")
    else:
        expected_model_type = train_kwargs.pop("model_type")
        if model_type[:5] != expected_model_type:
            warnings.warn("You have specified a different model type.")

    train_kwargs.update(**kwargs)
    if train_instance_segmentation_only:
        instance_seg_kwargs, extra_kwargs = split_kwargs(train_instance_segmentation, **train_kwargs)
        model_kwargs, extra_kwargs = split_kwargs(get_sam_model, **extra_kwargs)
        instance_seg_kwargs.update(**model_kwargs)

        train_instance_segmentation(
            name=name,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_path=checkpoint_path,
            model_type=model_type,
            **instance_seg_kwargs,
        )
    else:
        train_sam(
            name=name,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_path=checkpoint_path,
            with_segmentation_decoder=with_segmentation_decoder,
            model_type=model_type,
            **train_kwargs
        )
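A short sketch of configuration-based training (loaders as above; extra keyword arguments such as `n_epochs` are forwarded to `train_sam`):

train_sam_for_configuration(
    name="sam_finetuned",
    train_loader=train_loader,
    val_loader=val_loader,
    configuration=None,  # pick the best configuration for the available VRAM automatically
    with_segmentation_decoder=True,
    n_epochs=100,
)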
def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader):

    # Whether the model is stored in the current working directory or in another location.
    if save_root is None:
        save_root = os.getcwd()  # Map this to the current working directory, if not specified by the user.

    # Get the 'best' model checkpoint ready for export.
    best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt")
    if not os.path.exists(best_checkpoint):
        raise FileNotFoundError(f"The trained model was not found at the expected location: '{best_checkpoint}'.")

    # Export the model if an output path has been given.
    if output_path:

        # If the filepath has a pytorch-specific ending, then we just export the checkpoint.
        if os.path.splitext(output_path)[1] in (".pt", ".pth"):
            export_custom_sam_model(
                checkpoint_path=best_checkpoint,
                model_type=model_type[:5],
                save_path=output_path,
                with_segmentation_decoder=with_segmentation_decoder,
            )

        # Otherwise we export it as a bioimage.io model.
        else:
            from micro_sam.bioimageio import export_sam_model

            # Load an image and corresponding labels from the val loader.
            with torch.no_grad():
                image_data, label_data = next(iter(val_loader))
                image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze()

                # Select the first channel of the label image if we have a channel axis, i.e. it contains the labels.
                if label_data.ndim == 3:
                    label_data = label_data[0]  # Gets the channel with instances.
                assert image_data.shape == label_data.shape
                label_data = label_data.astype("uint32")

                export_sam_model(
                    image=image_data,
                    label_image=label_data,
                    model_type=model_type[:5],
                    name=checkpoint_name,
                    output_path=output_path,
                    checkpoint_path=best_checkpoint,
                )

        # The final path where the model has been stored.
        final_path = output_path

    else:  # If no exports have been made, inform the user about the best checkpoint.
        final_path = best_checkpoint

    return final_path
" 1034 "The following options are possible:\n" 1035 "- 'instances' to train with an additional decoder for automatic instance segmentation. " 1036 " This option enables using the automatic instance segmentation (AIS) mode.\n" 1037 "- 'instances_only' to train only the instance segmentation decoder. " 1038 " In this case the parts of SAM that are used for interactive segmentation will not be trained.\n" 1039 "- 'None' to train without an additional segmentation decoder." 1040 " This options trains only the parts of the original SAM.\n" 1041 "By default the option 'instances' is used." 1042 ) 1043 1044 # Optional advanced settings a user can opt to change the values for. 1045 parser.add_argument( 1046 "-d", "--device", type=str, default=None, 1047 help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). " 1048 "By default the most performant available device will be selected." 1049 ) 1050 parser.add_argument( 1051 "--patch_shape", type=int, nargs="*", default=(512, 512), 1052 help="The choice of patch shape for training Segment Anything Model. " 1053 "By default, a patch size of 512x512 is used." 1054 ) 1055 parser.add_argument( 1056 "-m", "--model_type", type=str, default=None, 1057 help=f"The Segment Anything Model that will be used for finetuning, one of {available_models}." 1058 ) 1059 parser.add_argument( 1060 "--checkpoint_path", type=str, default=None, 1061 help="Checkpoint from which the SAM model will be loaded for finetuning." 1062 ) 1063 parser.add_argument( 1064 "-s", "--save_root", type=str, default=None, 1065 help="The directory where the trained models and corresponding logs will be stored. " 1066 "By default, there are stored in your current working directory." 1067 ) 1068 parser.add_argument( 1069 "--trained_model_name", type=str, default="sam_model", 1070 help="The custom name of trained model sub-folder. Allows users to have several trained models " 1071 "under the same 'save_root'." 1072 ) 1073 parser.add_argument( 1074 "--output_path", type=str, default=None, 1075 help="The directory (eg. '/path/to/folder') or filepath (eg. '/path/to/model.pt') to export the trained model." 1076 ) 1077 parser.add_argument( 1078 "--n_epochs", type=int, default=100, 1079 help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs." 1080 ) 1081 parser.add_argument( 1082 "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders." 1083 ) 1084 parser.add_argument( 1085 "--batch_size", type=int, default=1, 1086 help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1." 1087 ) 1088 parser.add_argument( 1089 "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"), 1090 help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images " 1091 "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'." 1092 ) 1093 1094 args = parser.parse_args() 1095 1096 # 1. Get all necessary stuff for training. 
def main():
    """@private"""
    import argparse

    available_models = list(get_model_names())
    available_configurations = list(CONFIGURATIONS.keys())

    parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.")

    # Images and labels for training.
    parser.add_argument(
        "--images", required=True, type=str, nargs="*",
        help="Filepath to images or the directory where the image data is stored."
    )
    parser.add_argument(
        "--labels", required=True, type=str, nargs="*",
        help="Filepath to ground-truth labels or the directory where the label data is stored."
    )
    parser.add_argument(
        "--image_key", type=str, default=None,
        help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--label_key", type=str, default=None,
        help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file."
    )

    # Images and labels for validation.
    # NOTE: This isn't required, i.e. we create a val-split on-the-fly from the training data if not provided.
    # Users can choose to have their explicit validation set via this feature as well.
    parser.add_argument(
        "--val_images", type=str, nargs="*",
        help="Filepath to images for validation or the directory where the image data is stored."
    )
    parser.add_argument(
        "--val_labels", type=str, nargs="*",
        help="Filepath to ground-truth labels for validation or the directory where the label data is stored."
    )
    parser.add_argument(
        "--val_image_key", type=str, default=None,
        help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file."
    )
    parser.add_argument(
        "--val_label_key", type=str, default=None,
        help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file."
    )

    # Other necessary settings for training.
    parser.add_argument(
        "--configuration", type=str, default=_find_best_configuration(),
        help=f"The configuration for finetuning the Segment Anything Model, one of {', '.join(available_configurations)}."
    )

    def none_or_str(value):
        if value.lower() == 'none':
            return None
        return value

    # This could be extended to train for semantic segmentation or other options.
    parser.add_argument(
        "--segmentation_decoder", type=none_or_str, default="instances",
        help="Whether to finetune the Segment Anything Model with an additional segmentation decoder. "
             "The following options are possible:\n"
             "- 'instances' to train with an additional decoder for automatic instance segmentation. "
             "  This option enables using the automatic instance segmentation (AIS) mode.\n"
             "- 'instances_only' to train only the instance segmentation decoder. "
             "  In this case the parts of SAM that are used for interactive segmentation will not be trained.\n"
             "- 'None' to train without an additional segmentation decoder. "
             "  This option trains only the parts of the original SAM.\n"
             "By default the option 'instances' is used."
    )

    # Optional advanced settings a user can opt to change the values for.
    parser.add_argument(
        "-d", "--device", type=str, default=None,
        help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
             "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs="*", default=(512, 512),
        help="The choice of patch shape for training the Segment Anything Model. "
             "By default, a patch size of 512x512 is used."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=None,
        help=f"The Segment Anything Model that will be used for finetuning, one of {', '.join(available_models)}."
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded for finetuning."
    )
    parser.add_argument(
        "-s", "--save_root", type=str, default=None,
        help="The directory where the trained models and corresponding logs will be stored. "
             "By default, they are stored in your current working directory."
    )
    parser.add_argument(
        "--trained_model_name", type=str, default="sam_model",
        help="The custom name of the trained model sub-folder. Allows users to have several trained models "
             "under the same 'save_root'."
    )
    parser.add_argument(
        "--output_path", type=str, default=None,
        help="The directory (eg. '/path/to/folder') or filepath (eg. '/path/to/model.pt') to export the trained model."
    )
    parser.add_argument(
        "--n_epochs", type=int, default=100,
        help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs."
    )
    parser.add_argument(
        "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1."
    )
    parser.add_argument(
        "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"),
        help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images. "
             "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'."
    )

    args = parser.parse_args()

    # 1. Get all necessary parameters for training.
    checkpoint_name = args.trained_model_name
    config = args.configuration
    model_type = args.model_type
    checkpoint_path = args.checkpoint_path
    batch_size = args.batch_size
    patch_shape = args.patch_shape
    epochs = args.n_epochs
    num_workers = args.num_workers
    device = args.device
    save_root = args.save_root
    output_path = args.output_path
    with_segmentation_decoder, train_instance_segmentation_only = _parse_segmentation_decoder(
        args.segmentation_decoder
    )

    # Get image paths and corresponding keys.
    train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key
    val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key  # noqa

    # 2. Prepare the dataloaders.

    # If the user wants to preprocess the inputs, we allow the possibility to do so.
    _raw_transform = get_raw_transform(args.preprocess)

    # Get the dataset with files for training.
    dataset = default_sam_dataset(
        raw_paths=train_images,
        raw_key=train_image_key,
        label_paths=train_gt,
        label_key=train_gt_key,
        patch_shape=patch_shape,
        with_segmentation_decoder=with_segmentation_decoder,
        raw_transform=_raw_transform,
        train_instance_segmentation_only=train_instance_segmentation_only,
    )

    # If validation images are not explicitly provided, we create a val split from the training data.
    if val_images is None:
        assert val_gt is None and val_image_key is None and val_gt_key is None
        # Use 10% of the dataset - at least one image - for validation.
        n_val = max(1, int(0.1 * len(dataset)))
        train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val])

    else:  # If val images are provided, we create a new dataset for them.
        train_dataset = dataset
        val_dataset = default_sam_dataset(
            raw_paths=val_images,
            raw_key=val_image_key,
            label_paths=val_gt,
            label_key=val_gt_key,
            patch_shape=patch_shape,
            with_segmentation_decoder=with_segmentation_decoder,
            train_instance_segmentation_only=train_instance_segmentation_only,
            raw_transform=_raw_transform,
        )

    # Get the dataloaders from the datasets.
    train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    # 3. Train the Segment Anything Model.

    # Get a valid model and other necessary parameters for training.
    if model_type is not None and model_type not in available_models:
        raise ValueError(f"'{model_type}' is not a valid choice of model.")
    if config is not None and config not in available_configurations:
        raise ValueError(f"'{config}' is not a valid choice of configuration.")

    if model_type is None:  # If the user does not specify the model, we use the default model corresponding to the config.
        model_type = CONFIGURATIONS[config]["model_type"]

    train_sam_for_configuration(
        name=checkpoint_name,
        configuration=config,
        model_type=model_type,
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=epochs,
        checkpoint_path=checkpoint_path,
        with_segmentation_decoder=with_segmentation_decoder,
        freeze=None,  # TODO: Allow for PEFT.
        device=device,
        save_root=save_root,
        peft_kwargs=None,  # TODO: Allow for PEFT.
        train_instance_segmentation_only=train_instance_segmentation_only,
    )

    # 4. Export the model, if desired by the user.
    if train_instance_segmentation_only and output_path:
        trained_path = os.path.join("" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt")
        export_instance_segmentation_model(trained_path, output_path, model_type, checkpoint_path)
        final_path = output_path
    else:
        final_path = _export_helper(
            save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader,
        )

    print(f"Training has finished. The trained model is saved at {final_path}.")
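The same training run can be started from the command line via the entry point registered for this `main` function. The `micro_sam.train` command name is assumed here (check `--help` in your installation); all paths are placeholders, and the flags correspond to the argparse options defined above:

micro_sam.train --images data/train/images --labels data/train/labels \
    --image_key "*.tif" --label_key "*.tif" \
    --segmentation_decoder instances \
    --trained_model_name my_sam_model \
    --output_path finetuned_sam.pt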
train_sam
Run training for a SAM model.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
- n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- device: The device to use for training. By default, automatically chooses the best available device to train.
- lr: The learning rate. By default, set to '1e-5'.
- n_sub_iteration: The number of iterative prompts per training iteration. By default, the number of iterations is set to '8'.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- mask_prob: The probability for using a mask as input in a given training sub-iteration. By default, set to '0.5'.
- n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in ReduceLROnPlateau.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
- verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
- box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. By default, the distortion factor is set to '0.025'.
- overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
- model_kwargs: Additional keyword arguments for the micro_sam.util.get_sam_model.
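For orientation, training can be driven end to end together with the loader helper documented further below. The following is a minimal sketch; the data folders, glob patterns, patch shape and the 'vit_b' model type are placeholder choices, not recommendations.

from micro_sam.training import default_sam_loader, train_sam

# Hypothetical folders with 2D training images and matching instance label masks.
train_loader = default_sam_loader(
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True,
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True,
)

# Fine-tune a vit_b SAM model together with the instance segmentation decoder.
train_sam(
    name="sam_finetuned", model_type="vit_b",
    train_loader=train_loader, val_loader=val_loader,
    n_epochs=10, with_segmentation_decoder=True,
)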
373def export_instance_segmentation_model( 374 trained_model_path: Union[str, os.PathLike], 375 output_path: Union[str, os.PathLike], 376 model_type: str, 377 initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None, 378) -> None: 379 """Export a model trained for instance segmentation with `train_instance_segmentation`. 380 381 The exported model will be compatible with the micro_sam functions, CLI and napari plugin. 382 It should only be used for automatic segmentation and may not work well for interactive segmentation. 383 384 Args: 385 trained_model_path: The path to the checkpoint of the model trained for instance segmentation. 386 output_path: The path where the exported model will be saved. 387 model_type: The model type. 388 initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional). 389 """ 390 trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"] 391 392 # Get the state of the encoder and instance segmentation decoder from the trained checkpoint. 393 encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")]) 394 decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")]) 395 396 # Load the original state of the model that was used as the basis of instance segmentation training. 397 _, model_state = get_sam_model( 398 model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu", 399 ) 400 # Remove the sam prefix if it's in the model state. 401 prefix = "sam." 402 model_state = OrderedDict( 403 [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] 404 ) 405 406 # Replace the image encoder state. 407 model_state = OrderedDict( 408 [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v) 409 for k, v in model_state.items()] 410 ) 411 412 save_state = {"model_state": model_state, "decoder_state": decoder_state} 413 torch.save(save_state, output_path)
Export a model trained for instance segmentation with train_instance_segmentation.
The exported model will be compatible with the micro_sam functions, CLI and napari plugin. It should only be used for automatic segmentation and may not work well for interactive segmentation.
Arguments:
- trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
- output_path: The path where the exported model will be saved.
- model_type: The model type.
- initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
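As a usage sketch, assuming a model named 'ais_model' was trained with train_instance_segmentation and that the checkpoint file inside 'checkpoints/<name>' follows torch_em's naming convention ('best.pt'); adjust both paths to your setup.

from micro_sam.training.training import export_instance_segmentation_model

export_instance_segmentation_model(
    trained_model_path="checkpoints/ais_model/best.pt",  # assumed torch_em checkpoint layout
    output_path="ais_model_exported.pt",
    model_type="vit_b",  # must match the model type used for training
)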
416def train_instance_segmentation( 417 name: str, 418 model_type: str, 419 train_loader: DataLoader, 420 val_loader: DataLoader, 421 n_epochs: int = 100, 422 early_stopping: Optional[int] = 10, 423 loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True), 424 metric: Optional[torch.nn.Module] = None, 425 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 426 freeze: Optional[List[str]] = None, 427 device: Optional[Union[str, torch.device]] = None, 428 lr: float = 1e-5, 429 save_root: Optional[Union[str, os.PathLike]] = None, 430 n_iterations: Optional[int] = None, 431 scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau, 432 scheduler_kwargs: Optional[Dict[str, Any]] = None, 433 save_every_kth_epoch: Optional[int] = None, 434 pbar_signals: Optional[QObject] = None, 435 optimizer_class: Optional[Optimizer] = torch.optim.AdamW, 436 peft_kwargs: Optional[Dict] = None, 437 ignore_warnings: bool = True, 438 overwrite_training: bool = True, 439 **model_kwargs, 440) -> None: 441 """Train a UNETR for instance segmentation using the SAM encoder as backbone. 442 443 This setting corresponds to training a SAM model with an instance segmentation decoder, 444 without training the model parts for interactive segmentation, 445 i.e. without training the prompt encoder and mask decoder. 446 447 The checkpoint of the trained model, which will be saved in 'checkpoints/<name>', 448 will not be compatible with the micro_sam functionality. 449 You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it 450 in a format that is compatible with micro_sam functionality. 451 Note that the exported model should only be used for automatic segmentation via AIS. 452 453 Args: 454 name: The name of the model to be trained. The checkpoint and logs will have this name. 455 model_type: The type of the SAM model. 456 train_loader: The dataloader for training. 457 val_loader: The dataloader for validation. 458 n_epochs: The number of epochs to train for. 459 early_stopping: Enable early stopping after this number of epochs without improvement. 460 By default, the value is set to '10' epochs. 461 loss: The loss function to train the instance segmentation model. 462 By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss' 463 metric: The metric for the instance segmentation training. 464 By default the loss function is used as the metric. 465 checkpoint_path: Path to checkpoint for initializing the SAM model. 466 freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. 467 By default nothing is frozen and the full model is updated. 468 device: The device to use for training. By default, automatically chooses the best available device to train. 469 lr: The learning rate. By default, set to '1e-5'. 470 save_root: Optional root directory for saving the checkpoints and logs. 471 If not given the current working directory is used. 472 n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given. 473 scheduler_class: The learning rate scheduler to update the learning rate. 474 By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used. 475 scheduler_kwargs: The learning rate scheduler parameters. 476 If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`. 477 save_every_kth_epoch: Save checkpoints after every kth epoch separately. 478 pbar_signals: Controls for napari progress bar. 
479        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
480        peft_kwargs: Keyword arguments for the PEFT wrapper class.
481        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
482        overwrite_training: Whether to overwrite the trained model stored at the same location.
483            By default, overwrites the trained model at each run.
484            If set to 'False', it will avoid retraining the model if the previous run was completed.
485        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
486    """
487
488    with _filter_warnings(ignore_warnings):
489        t_start = time.time()
490
491        sam_model, state = get_trainable_sam_model(
492            model_type=model_type,
493            device=device,
494            checkpoint_path=checkpoint_path,
495            return_state=True,
496            peft_kwargs=peft_kwargs,
497            freeze=freeze,
498            **model_kwargs
499        )
500        device = get_device(device)
501        model = get_unetr(
502            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), device=device,
503        )
504
505        optimizer, scheduler = _get_optimizer_and_scheduler(
506            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
507        )
508        trainer = torch_em.trainer.DefaultTrainer(
509            name=name,
510            model=model,
511            train_loader=train_loader,
512            val_loader=val_loader,
513            device=device,
514            mixed_precision=True,
515            log_image_interval=50,
516            compile_model=False,
517            save_root=save_root,
518            loss=loss,
519            metric=loss if metric is None else metric,
520            optimizer=optimizer,
521            lr_scheduler=scheduler,
522            early_stopping=early_stopping,
523        )
524
525        trainer_fit_params = _get_trainer_fit_params(
526            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
527        )
528        trainer.fit(**trainer_fit_params)
529
530        t_run = time.time() - t_start
531        hours = int(t_run // 3600)
532        minutes = int((t_run % 3600) // 60)  # minutes within the hour; t_run // 60 would count total minutes
533        seconds = int(round(t_run % 60, 0))
534        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
Train a UNETR for instance segmentation using the SAM encoder as backbone.
This setting corresponds to training a SAM model with an instance segmentation decoder, without training the model parts for interactive segmentation, i.e. without training the prompt encoder and mask decoder.
The checkpoint of the trained model, which will be saved in 'checkpoints/<name>', will not be compatible with the micro_sam functionality. You can call the function export_instance_segmentation_model with the path to the checkpoint to export it in a format that is compatible with micro_sam functionality.
Note that the exported model should only be used for automatic segmentation via AIS.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
- loss: The loss function to train the instance segmentation model. By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
- metric: The metric for the instance segmentation training. By default the loss function is used as the metric.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. By default nothing is frozen and the full model is updated.
- device: The device to use for training. By default, automatically chooses the best available device to train.
- lr: The learning rate. By default, set to '1e-5'.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in ReduceLROnPlateau.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, torch.optim.AdamW is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
- overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
- model_kwargs: Additional keyword arguments for the micro_sam.util.get_sam_model.
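A minimal sketch of this training mode with placeholder paths. Note that the loaders are created with train_instance_segmentation_only=True (see default_sam_dataset below), since the instance channel must not be passed in this mode.

from micro_sam.training import default_sam_loader
from micro_sam.training.training import train_instance_segmentation

train_loader = default_sam_loader(
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512),
    with_segmentation_decoder=True,
    train_instance_segmentation_only=True,
    batch_size=1, shuffle=True,
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    patch_shape=(512, 512),
    with_segmentation_decoder=True,
    train_instance_segmentation_only=True,
    batch_size=1, shuffle=True,
)

# Train only the UNETR decoder on top of the (optionally frozen) SAM encoder.
train_instance_segmentation(
    name="ais_model", model_type="vit_b",
    train_loader=train_loader, val_loader=val_loader,
)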
615def default_sam_dataset( 616 raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath], 617 raw_key: Optional[str], 618 label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath], 619 label_key: Optional[str], 620 patch_shape: Tuple[int], 621 with_segmentation_decoder: bool, 622 with_channels: Optional[bool] = None, 623 train_instance_segmentation_only: bool = False, 624 sampler: Optional[Callable] = None, 625 raw_transform: Optional[Callable] = None, 626 n_samples: Optional[int] = None, 627 is_train: bool = True, 628 min_size: int = 25, 629 max_sampling_attempts: Optional[int] = None, 630 rois: Optional[Union[slice, Tuple[slice, ...]]] = None, 631 is_multi_tensor: bool = True, 632 **kwargs, 633) -> Dataset: 634 """Create a PyTorch Dataset for training a SAM model. 635 636 Args: 637 raw_paths: The path(s) to the image data used for training. 638 Can either be multiple 2D images or volumetric data. 639 The data can also be passed as a list of numpy arrays or torch tensors. 640 raw_key: The key for accessing the image data. Internal filepath for hdf5-like input 641 or a glob pattern for selecting multiple files. 642 Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors. 643 label_paths: The path(s) to the label data used for training. 644 Can either be multiple 2D images or volumetric data. 645 The data can also be passed as a list of numpy arrays or torch tensors. 646 label_key: The key for accessing the label data. Internal filepath for hdf5-like input 647 or a glob pattern for selecting multiple files. 648 Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors. 649 patch_shape: The shape for training patches. 650 with_segmentation_decoder: Whether to train with additional segmentation decoder. 651 with_channels: Whether the image data has channels. By default, it makes the decision based on inputs. 652 train_instance_segmentation_only: Set this argument to True in order to 653 pass the dataset to `train_instance_segmentation`. By default, set to 'False'. 654 sampler: A sampler to reject batches according to a given criterion. 655 raw_transform: Transformation applied to the image data. 656 If not given the data will be cast to 8bit. 657 n_samples: The number of samples for this dataset. 658 is_train: Whether this dataset is used for training or validation. By default, set to 'True'. 659 min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'. 660 max_sampling_attempts: Number of sampling attempts to make from a dataset. 661 rois: The region of interest(s) for the data. 662 is_multi_tensor: Whether the input data to data transforms is multiple tensors or not. 663 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 664 665 Returns: 666 The segmentation dataset. 667 """ 668 669 # Check if this dataset should be used for instance segmentation only training. 670 # If yes, we set return_instances to False, since the instance channel must not 671 # be passed for this training mode. 672 return_instances = True 673 if train_instance_segmentation_only: 674 if not with_segmentation_decoder: 675 raise ValueError( 676 "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True." 677 ) 678 return_instances = False 679 680 # If a sampler is not passed, then we set a MinInstanceSampler, which requires 3 distinct instances per sample. 
681 # This is necessary, because training for interactive segmentation does not work on 'empty' images. 682 # However, if we train only the automatic instance segmentation decoder, then this sampler is not required 683 # and we do not set a default sampler. 684 if sampler is None and not train_instance_segmentation_only: 685 sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size) 686 687 is_seg_dataset, with_channels = _determine_dataset_and_channels( 688 raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs 689 ) 690 691 # Set the data transformations. 692 if raw_transform is None: 693 raw_transform = require_8bit 694 695 # Prepare the label transform. 696 if with_segmentation_decoder: 697 default_label_transform = torch_em.transform.label.PerObjectDistanceTransform( 698 distances=True, 699 boundary_distances=True, 700 directed_distances=False, 701 foreground=True, 702 instances=return_instances, 703 min_size=min_size, 704 ) 705 else: 706 default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size) 707 708 # Allow combining label transforms. 709 custom_label_transform = kwargs.pop("label_transform", None) 710 if custom_label_transform is None: 711 label_transform = default_label_transform 712 else: 713 label_transform = torch_em.transform.generic.Compose( 714 custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor 715 ) 716 717 # Check the patch shape to add a singleton if required. 718 patch_shape = _update_patch_shape( 719 patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels, 720 ) 721 722 # Set a minimum number of samples per epoch. 723 if n_samples is None: 724 loader = torch_em.default_segmentation_loader( 725 raw_paths=raw_paths, 726 raw_key=raw_key, 727 label_paths=label_paths, 728 label_key=label_key, 729 batch_size=1, 730 patch_shape=patch_shape, 731 with_channels=with_channels, 732 ndim=2, 733 is_seg_dataset=is_seg_dataset, 734 raw_transform=raw_transform, 735 rois=rois, 736 **kwargs 737 ) 738 n_samples = max(len(loader), 100 if is_train else 5) 739 740 dataset = torch_em.default_segmentation_dataset( 741 raw_paths=raw_paths, 742 raw_key=raw_key, 743 label_paths=label_paths, 744 label_key=label_key, 745 patch_shape=patch_shape, 746 raw_transform=raw_transform, 747 label_transform=label_transform, 748 with_channels=with_channels, 749 ndim=2, 750 sampler=sampler, 751 n_samples=n_samples, 752 is_seg_dataset=is_seg_dataset, 753 rois=rois, 754 **kwargs, 755 ) 756 757 if max_sampling_attempts is not None: 758 if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset): 759 for ds in dataset.datasets: 760 ds.max_sampling_attempts = max_sampling_attempts 761 else: 762 dataset.max_sampling_attempts = max_sampling_attempts 763 764 return dataset
Create a PyTorch Dataset for training a SAM model.
Arguments:
- raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data. The data can also be passed as a list of numpy arrays or torch tensors.
- raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files. Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
- label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data. The data can also be passed as a list of numpy arrays or torch tensors.
- label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files. Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
- patch_shape: The shape for training patches.
- with_segmentation_decoder: Whether to train with additional segmentation decoder.
- with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
- train_instance_segmentation_only: Set this argument to True in order to pass the dataset to train_instance_segmentation. By default, set to 'False'.
- sampler: A sampler to reject batches according to a given criterion.
- raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
- n_samples: The number of samples for this dataset.
- is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
- min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
- max_sampling_attempts: Number of sampling attempts to make from a dataset.
- rois: The region of interest(s) for the data.
- is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
- kwargs: Additional keyword arguments for torch_em.default_segmentation_dataset.
Returns:
The segmentation dataset.
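Because the function also accepts in-memory arrays, a self-contained sketch can run on synthetic data. The two rectangles per label image are only there so that the default MinInstanceSampler finds enough instances; real training requires real annotations.

import numpy as np
from micro_sam.training import default_sam_dataset

# Synthetic stand-ins for real data: two 8-bit images and label images
# with two instances each, so the default sampler does not reject every patch.
images, labels = [], []
for _ in range(2):
    images.append(np.random.randint(0, 256, (512, 512), dtype="uint8"))
    lab = np.zeros((512, 512), dtype="int64")
    lab[50:150, 50:150] = 1
    lab[200:330, 200:330] = 2
    labels.append(lab)

ds = default_sam_dataset(
    raw_paths=images, raw_key=None,  # keys are None for in-memory data
    label_paths=labels, label_key=None,
    patch_shape=(512, 512), with_segmentation_decoder=True,
)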
767def default_sam_loader(**kwargs) -> DataLoader: 768 """Create a PyTorch DataLoader for training a SAM model. 769 770 Args: 771 kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader. 772 773 Returns: 774 The DataLoader. 775 """ 776 sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs) 777 778 # There might be additional parameters supported by `torch_em.default_segmentation_dataset`, 779 # which the users can provide to get their desired segmentation dataset. 780 extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs) 781 ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs} 782 783 ds = default_sam_dataset(**ds_kwargs) 784 return torch_em.segmentation.get_data_loader(ds, **loader_kwargs)
Create a PyTorch DataLoader for training a SAM model.
Arguments:
- kwargs: Keyword arguments for micro_sam.training.default_sam_dataset or for the PyTorch DataLoader.
Returns:
The DataLoader.
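Thanks to this kwarg splitting, dataset arguments and DataLoader arguments can be mixed in a single call, as in the following sketch with placeholder paths.

from micro_sam.training import default_sam_loader

loader = default_sam_loader(
    # Consumed by default_sam_dataset / torch_em.default_segmentation_dataset:
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    # Forwarded to the PyTorch DataLoader:
    batch_size=2, shuffle=True, num_workers=4,
)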
Best training configurations for given hardware resources.
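Since the dictionary is a module attribute, the available configuration names and the settings they map to can be inspected directly:

from micro_sam.training.training import CONFIGURATIONS

# List which hardware configurations are available and what they configure.
for config_name, settings in CONFIGURATIONS.items():
    print(config_name, "->", settings)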
826def train_sam_for_configuration( 827 name: str, 828 train_loader: DataLoader, 829 val_loader: DataLoader, 830 configuration: Optional[str] = None, 831 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 832 with_segmentation_decoder: bool = True, 833 train_instance_segmentation_only: bool = False, 834 model_type: Optional[str] = None, 835 **kwargs, 836) -> None: 837 """Run training for a SAM model with the configuration for a given hardware resource. 838 839 Selects the best training settings for the given configuration. 840 The available configurations are listed in `CONFIGURATIONS`. 841 842 Args: 843 name: The name of the model to be trained. The checkpoint and logs folder will have this name. 844 train_loader: The dataloader for training. 845 val_loader: The dataloader for validation. 846 configuration: The configuration (= name of hardware resource). 847 By default, it is automatically selected for the best VRAM combination. 848 checkpoint_path: Path to checkpoint for initializing the SAM model. 849 with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. 850 By default, trains with the additional instance segmentation decoder. 851 train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation 852 using the training implementation `train_instance_segmentation`. By default, `train_sam` is used. 853 model_type: Over-ride the default model type. 854 This can be used to use one of the micro_sam models as starting point 855 instead of a default sam model. 856 kwargs: Additional keyword parameters that will be passed to `train_sam`. 857 """ 858 if configuration is None: # Automatically choose based on available VRAM combination. 859 configuration = _find_best_configuration() 860 861 if configuration in CONFIGURATIONS: 862 train_kwargs = CONFIGURATIONS[configuration] 863 else: 864 raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}") 865 866 if model_type is None: 867 model_type = train_kwargs.pop("model_type") 868 else: 869 expected_model_type = train_kwargs.pop("model_type") 870 if model_type[:5] != expected_model_type: 871 warnings.warn("You have specified a different model type.") 872 873 train_kwargs.update(**kwargs) 874 if train_instance_segmentation_only: 875 instance_seg_kwargs, extra_kwargs = split_kwargs(train_instance_segmentation, **train_kwargs) 876 model_kwargs, extra_kwargs = split_kwargs(get_sam_model, **extra_kwargs) 877 instance_seg_kwargs.update(**model_kwargs) 878 879 train_instance_segmentation( 880 name=name, 881 train_loader=train_loader, 882 val_loader=val_loader, 883 checkpoint_path=checkpoint_path, 884 model_type=model_type, 885 **instance_seg_kwargs, 886 ) 887 else: 888 train_sam( 889 name=name, 890 train_loader=train_loader, 891 val_loader=val_loader, 892 checkpoint_path=checkpoint_path, 893 with_segmentation_decoder=with_segmentation_decoder, 894 model_type=model_type, 895 **train_kwargs 896 )
Run training for a SAM model with the configuration for a given hardware resource.
Selects the best training settings for the given configuration.
The available configurations are listed in CONFIGURATIONS.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs folder will have this name.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- configuration: The configuration (= name of hardware resource). By default, it is automatically selected for the best VRAM combination.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train an additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
- train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation using the training implementation train_instance_segmentation. By default, train_sam is used.
- model_type: Over-ride the default model type. This can be used to use one of the micro_sam models as starting point instead of a default sam model.
- kwargs: Additional keyword parameters that will be passed to train_sam.
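A minimal sketch with placeholder paths; leaving configuration as None lets the function choose a setting based on the available VRAM.

from micro_sam.training import default_sam_loader
from micro_sam.training.training import train_sam_for_configuration

train_loader = default_sam_loader(
    raw_paths="data/train/images", raw_key="*.tif",
    label_paths="data/train/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True,
)
val_loader = default_sam_loader(
    raw_paths="data/val/images", raw_key="*.tif",
    label_paths="data/val/labels", label_key="*.tif",
    patch_shape=(512, 512), with_segmentation_decoder=True,
    batch_size=1, shuffle=True,
)

train_sam_for_configuration(
    name="sam_finetuned",
    train_loader=train_loader, val_loader=val_loader,
    configuration=None,  # auto-select from CONFIGURATIONS based on available VRAM
)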