micro_sam.training.training
import os
import time
import warnings
from glob import glob
from tqdm import tqdm
from collections import OrderedDict
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio
import numpy as np
import torch
from torch.optim import Optimizer
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler

import torch_em
from torch_em.util import load_data
from torch_em.data.datasets.util import split_kwargs

from elf.io import open_file

try:
    from qtpy.QtCore import QObject
except Exception:
    QObject = Any

from . import sam_trainer as trainers
from ..instance_segmentation import get_unetr
from ..models.peft_sam import ClassicalSurgery
from . import joint_sam_trainer as joint_trainers
from ..util import get_device, get_model_names, export_custom_sam_model, get_sam_model
from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform


FilePath = Union[str, os.PathLike]


def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None):
    """Validate the raw data and the targets yielded by a data loader.

    The first batch is used to check the raw data (channel count and value range).
    Afterwards the targets of up to `verify_n_labels_in_loader` batches are checked.

    Args:
        loader: The data loader; it must yield `(raw, target)` batches with the channel axis at index 1.
        with_segmentation_decoder: Whether targets for the instance segmentation decoder are expected
            (4 target channels; the first one holds the instances) or only instance labels (1 channel).
        name: Optional name of the loader, only used in the progress bar description.
        verify_n_labels_in_loader: The number of batches to verify. If None, the whole loader is verified.

    Raises:
        ValueError: If the raw data or the targets violate one of the checks.
    """
    x, _ = next(iter(loader))

    # Raw data: check that we have 1 or 3 channels.
    n_channels = x.shape[1]
    if n_channels not in (1, 3):
        raise ValueError(
            "Invalid number of channels for the input data from the data loader. "
            f"Expect 1 or 3 channels, got {n_channels}."
        )

    # Raw data: check that it is between [0, 255]
    minval, maxval = x.min(), x.max()
    if minval < 0 or minval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got minimum value {minval}."
        )
    if maxval < 1 or maxval > 255:
        raise ValueError(
            "Invalid data range for the input data from the data loader. "
            f"The input has to be in range [0, 255], but got maximum value {maxval}."
        )

    # Target data: the check depends on whether we train with or without decoder.
    # NOTE: Verification step to check whether all labels from dataloader are valid (i.e. have at least one instance).

    def _check_instance_channel(instance_channel):
        unique_vals = torch.unique(instance_channel)
        if (unique_vals < 0).any():
            raise ValueError(
                "The target channel with the instance segmentation must not have negative values."
            )
        # A single unique value means that there is no second id, i.e. no object next to the background.
        if len(unique_vals) == 1:
            raise ValueError(
                "The target channel with the instance segmentation must have at least one instance."
            )
        if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
            raise ValueError(
                "All values in the target channel with the instance segmentation must be integer."
            )

    name = "" if name is None else f"'{name}'"
    batches = tqdm(
        loader,
        desc=f"Verifying labels in {name} dataloader",
        total=verify_n_labels_in_loader,
    )
    for n_verified, (_, y) in enumerate(batches, start=1):
        n_channels_y = y.shape[1]
        if with_segmentation_decoder:
            if n_channels_y != 4:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 4 channel for training with an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check instance channel per sample in a batch
            for per_y_sample in y:
                _check_instance_channel(per_y_sample[0])

            targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
            if targets_min < 0 or targets_min > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got min {targets_min}"
                )
            if targets_max < 0 or targets_max > 1:
                raise ValueError(
                    "Invalid value range in the target data from the data loader. "
                    "Expect the 3 last target channels (for normalized distances and foreground probabilities) "
                    f"to be in range [0.0, 1.0], but got max {targets_max}"
                )

        else:
            if n_channels_y != 1:
                raise ValueError(
                    "Invalid number of channels in the target data from the data loader. "
                    "Expect 1 channel for training without an instance segmentation decoder, "
                    f"but got {n_channels_y} channels."
                )
            # Check instance channel per sample in a batch
            for per_y_sample in y:
                _check_instance_channel(per_y_sample)

        # Stop once exactly the requested number of batches has been verified
        # (this matches the total shown in the progress bar).
        if verify_n_labels_in_loader is not None and n_verified >= verify_n_labels_in_loader:
            break


# Make the progress bar callbacks compatible with a tqdm progress bar interface.
132class _ProgressBarWrapper: 133 def __init__(self, signals): 134 self._signals = signals 135 self._total = None 136 137 @property 138 def total(self): 139 return self._total 140 141 @total.setter 142 def total(self, value): 143 self._signals.pbar_total.emit(value) 144 self._total = value 145 146 def update(self, steps): 147 self._signals.pbar_update.emit(steps) 148 149 def set_description(self, desc, **kwargs): 150 self._signals.pbar_description.emit(desc) 151 152 153@contextmanager 154def _filter_warnings(ignore_warnings): 155 if ignore_warnings: 156 with warnings.catch_warnings(): 157 warnings.simplefilter("ignore") 158 yield 159 else: 160 with nullcontext(): 161 yield 162 163 164def _count_parameters(model_parameters): 165 params = sum(p.numel() for p in model_parameters if p.requires_grad) 166 params = params / 1e6 167 print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)") 168 169 170def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training): 171 if n_iterations is None: 172 trainer_fit_params = {"epochs": n_epochs} 173 else: 174 trainer_fit_params = {"iterations": n_iterations} 175 176 if save_every_kth_epoch is not None: 177 trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch 178 179 if pbar_signals is not None: 180 progress_bar_wrapper = _ProgressBarWrapper(pbar_signals) 181 trainer_fit_params["progress"] = progress_bar_wrapper 182 183 # Avoid overwriting a trained model, if desired by the user. 
184 trainer_fit_params["overwrite_training"] = overwrite_training 185 return trainer_fit_params 186 187 188def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs): 189 optimizer = optimizer_class(model_params, lr=lr) 190 if scheduler_kwargs is None: 191 scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3} 192 scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs) 193 return optimizer, scheduler 194 195 196def train_sam( 197 name: str, 198 model_type: str, 199 train_loader: DataLoader, 200 val_loader: DataLoader, 201 n_epochs: int = 100, 202 early_stopping: Optional[int] = 10, 203 n_objects_per_batch: Optional[int] = 25, 204 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 205 with_segmentation_decoder: bool = True, 206 freeze: Optional[List[str]] = None, 207 device: Optional[Union[str, torch.device]] = None, 208 lr: float = 1e-5, 209 n_sub_iteration: int = 8, 210 save_root: Optional[Union[str, os.PathLike]] = None, 211 mask_prob: float = 0.5, 212 n_iterations: Optional[int] = None, 213 scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau, 214 scheduler_kwargs: Optional[Dict[str, Any]] = None, 215 save_every_kth_epoch: Optional[int] = None, 216 pbar_signals: Optional[QObject] = None, 217 optimizer_class: Optional[Optimizer] = torch.optim.AdamW, 218 peft_kwargs: Optional[Dict] = None, 219 ignore_warnings: bool = True, 220 verify_n_labels_in_loader: Optional[int] = 50, 221 box_distortion_factor: Optional[float] = 0.025, 222 overwrite_training: bool = True, 223 strict_decoder_loading: bool = True, 224 **model_kwargs, 225) -> None: 226 """Run training for a SAM model. 227 228 Args: 229 name: The name of the model to be trained. The checkpoint and logs will have this name. 230 model_type: The type of the SAM model. 231 train_loader: The dataloader for training. 232 val_loader: The dataloader for validation. 233 n_epochs: The number of epochs to train for. 
234 early_stopping: Enable early stopping after this number of epochs without improvement. 235 By default, the value is set to '10' epochs. 236 n_objects_per_batch: The number of objects per batch used to compute 237 the loss for interative segmentation. If None all objects will be used, 238 if given objects will be randomly sub-sampled. By default, the number of objects per batch are '25'. 239 checkpoint_path: Path to checkpoint for initializing the SAM model. 240 with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. 241 By default, trains with the additional instance segmentation decoder. 242 freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder 243 By default nothing is frozen and the full model is updated. 244 device: The device to use for training. By default, automatically chooses the best available device to train. 245 lr: The learning rate. By default, set to '1e-5'. 246 n_sub_iteration: The number of iterative prompts per training iteration. 247 By default, the number of iterations is set to '8'. 248 save_root: Optional root directory for saving the checkpoints and logs. 249 If not given the current working directory is used. 250 mask_prob: The probability for using a mask as input in a given training sub-iteration. 251 By default, set to '0.5'. 252 n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given. 253 scheduler_class: The learning rate scheduler to update the learning rate. 254 By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used. 255 scheduler_kwargs: The learning rate scheduler parameters. 256 If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`. 257 save_every_kth_epoch: Save checkpoints after every kth epoch separately. 258 pbar_signals: Controls for napari progress bar. 259 optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used. 
260 peft_kwargs: Keyword arguments for the PEFT wrapper class. 261 ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'. 262 verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. 263 By default, 50 batches of labels are verified from the dataloaders. 264 box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. 265 By default, the distortion factor is set to '0.025'. 266 overwrite_training: Whether to overwrite the trained model stored at the same location. 267 By default, overwrites the trained model at each run. 268 If set to 'False', it will avoid retraining the model if the previous run was completed. 269 strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present, 270 exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output 271 channels if they were pre-trained for a different task. If set to False, decoders with a different 272 output dimension can be loaded; the output channels will be re-initialized. 273 model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`. 274 """ 275 with _filter_warnings(ignore_warnings): 276 277 t_start = time.time() 278 if verify_n_labels_in_loader is not None: 279 _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader) 280 _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader) 281 282 device = get_device(device) 283 # Get the trainable segment anything model. 284 model, state = get_trainable_sam_model( 285 model_type=model_type, 286 device=device, 287 freeze=freeze, 288 checkpoint_path=checkpoint_path, 289 return_state=True, 290 peft_kwargs=peft_kwargs, 291 **model_kwargs 292 ) 293 294 # This class creates all the training data for a batch (inputs, prompts and labels). 
295 convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor) 296 297 # Create the UNETR decoder (if train with it) and the optimizer. 298 if with_segmentation_decoder: 299 300 # Get the UNETR. 301 unetr = get_unetr( 302 image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None), 303 device=device, flexible_load_checkpoint=not strict_decoder_loading, 304 ) 305 306 # Get the parameters for SAM and the decoder from UNETR. 307 joint_model_params = [params for params in model.parameters()] # sam parameters 308 for param_name, params in unetr.named_parameters(): # unetr's decoder parameters 309 if not param_name.startswith("encoder"): 310 joint_model_params.append(params) 311 312 model_params = joint_model_params 313 else: 314 model_params = model.parameters() 315 316 optimizer, scheduler = _get_optimizer_and_scheduler( 317 model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs 318 ) 319 320 # The trainer which performs training and validation. 
321 if with_segmentation_decoder: 322 instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True) 323 trainer = joint_trainers.JointSamTrainer( 324 name=name, 325 save_root=save_root, 326 train_loader=train_loader, 327 val_loader=val_loader, 328 model=model, 329 optimizer=optimizer, 330 device=device, 331 lr_scheduler=scheduler, 332 logger=joint_trainers.JointSamLogger, 333 log_image_interval=100, 334 mixed_precision=True, 335 convert_inputs=convert_inputs, 336 n_objects_per_batch=n_objects_per_batch, 337 n_sub_iteration=n_sub_iteration, 338 compile_model=False, 339 unetr=unetr, 340 instance_loss=instance_seg_loss, 341 instance_metric=instance_seg_loss, 342 early_stopping=early_stopping, 343 mask_prob=mask_prob, 344 ) 345 else: 346 trainer = trainers.SamTrainer( 347 name=name, 348 train_loader=train_loader, 349 val_loader=val_loader, 350 model=model, 351 optimizer=optimizer, 352 device=device, 353 lr_scheduler=scheduler, 354 logger=trainers.SamLogger, 355 log_image_interval=100, 356 mixed_precision=True, 357 convert_inputs=convert_inputs, 358 n_objects_per_batch=n_objects_per_batch, 359 n_sub_iteration=n_sub_iteration, 360 compile_model=False, 361 early_stopping=early_stopping, 362 mask_prob=mask_prob, 363 save_root=save_root, 364 ) 365 366 trainer_fit_params = _get_trainer_fit_params( 367 n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training 368 ) 369 trainer.fit(**trainer_fit_params) 370 371 t_run = time.time() - t_start 372 hours = int(t_run // 3600) 373 minutes = int(t_run // 60) 374 seconds = int(round(t_run % 60, 0)) 375 print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)") 376 377 378def export_instance_segmentation_model( 379 trained_model_path: Union[str, os.PathLike], 380 output_path: Union[str, os.PathLike], 381 model_type: str, 382 initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None, 383) -> None: 384 """Export a model trained for instance segmentation with 
`train_instance_segmentation`. 385 386 The exported model will be compatible with the micro_sam functions, CLI and napari plugin. 387 It should only be used for automatic segmentation and may not work well for interactive segmentation. 388 389 Args: 390 trained_model_path: The path to the checkpoint of the model trained for instance segmentation. 391 output_path: The path where the exported model will be saved. 392 model_type: The model type. 393 initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional). 394 """ 395 state = torch.load(trained_model_path, weights_only=False, map_location="cpu") 396 trained_state = state.get("model_state", state) 397 398 # Get the state of the encoder and instance segmentation decoder from the trained checkpoint. 399 encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")]) 400 decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")]) 401 402 # Load the original state of the model that was used as the basis of instance segmentation training. 403 predictor = get_sam_model( 404 model_type=model_type, checkpoint_path=initial_checkpoint_path, device="cpu", 405 ) 406 model_state = OrderedDict(predictor.model.state_dict()) 407 408 # Replace the image encoder weights with the trained ones. 409 # UNETR stores the image encoder as "encoder.*"; SAM uses "image_encoder.*". 410 for k in list(model_state.keys()): 411 if k.startswith("image_encoder."): 412 encoder_key = "encoder." 
+ k[len("image_encoder."):] 413 model_state[k] = encoder_state[encoder_key] 414 415 save_state = {"model_state": model_state, "decoder_state": decoder_state} 416 torch.save(save_state, output_path) 417 418 419def train_instance_segmentation( 420 name: str, 421 model_type: str, 422 train_loader: DataLoader, 423 val_loader: DataLoader, 424 n_epochs: int = 100, 425 early_stopping: Optional[int] = 10, 426 loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True), 427 metric: Optional[torch.nn.Module] = None, 428 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 429 freeze: Optional[List[str]] = None, 430 device: Optional[Union[str, torch.device]] = None, 431 lr: float = 1e-5, 432 save_root: Optional[Union[str, os.PathLike]] = None, 433 n_iterations: Optional[int] = None, 434 scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau, 435 scheduler_kwargs: Optional[Dict[str, Any]] = None, 436 save_every_kth_epoch: Optional[int] = None, 437 pbar_signals: Optional[QObject] = None, 438 optimizer_class: Optional[Optimizer] = torch.optim.AdamW, 439 peft_kwargs: Optional[Dict] = None, 440 ignore_warnings: bool = True, 441 overwrite_training: bool = True, 442 strict_decoder_loading: bool = True, 443 **model_kwargs, 444) -> None: 445 """Train a UNETR for instance segmentation using the SAM encoder as backbone. 446 447 This setting corresponds to training a SAM model with an instance segmentation decoder, 448 without training the model parts for interactive segmentation, 449 i.e. without training the prompt encoder and mask decoder. 450 451 The checkpoint of the trained model, which will be saved in 'checkpoints/<name>', 452 will not be compatible with the micro_sam functionality. 453 You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it 454 in a format that is compatible with micro_sam functionality. 
455 Note that the exported model should only be used for automatic segmentation via AIS. 456 457 Args: 458 name: The name of the model to be trained. The checkpoint and logs will have this name. 459 model_type: The type of the SAM model. 460 train_loader: The dataloader for training. 461 val_loader: The dataloader for validation. 462 n_epochs: The number of epochs to train for. 463 early_stopping: Enable early stopping after this number of epochs without improvement. 464 By default, the value is set to '10' epochs. 465 loss: The loss function to train the instance segmentation model. 466 By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss' 467 metric: The metric for the instance segmentation training. 468 By default the loss function is used as the metric. 469 checkpoint_path: Path to checkpoint for initializing the SAM model. 470 freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. 471 By default nothing is frozen and the full model is updated. 472 device: The device to use for training. By default, automatically chooses the best available device to train. 473 lr: The learning rate. By default, set to '1e-5'. 474 save_root: Optional root directory for saving the checkpoints and logs. 475 If not given the current working directory is used. 476 n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given. 477 scheduler_class: The learning rate scheduler to update the learning rate. 478 By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used. 479 scheduler_kwargs: The learning rate scheduler parameters. 480 If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`. 481 save_every_kth_epoch: Save checkpoints after every kth epoch separately. 482 pbar_signals: Controls for napari progress bar. 483 optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used. 
484 peft_kwargs: Keyword arguments for the PEFT wrapper class. 485 ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'. 486 overwrite_training: Whether to overwrite the trained model stored at the same location. 487 By default, overwrites the trained model at each run. 488 If set to 'False', it will avoid retraining the model if the previous run was completed. 489 strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present, 490 exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output 491 channels if they were pre-trained for a different task. If set to False, decoders with a different 492 output dimension can be loaded; the output channels will be re-initialized. 493 model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`. 494 """ 495 496 with _filter_warnings(ignore_warnings): 497 t_start = time.time() 498 499 sam_model, state = get_trainable_sam_model( 500 model_type=model_type, 501 device=device, 502 checkpoint_path=checkpoint_path, 503 return_state=True, 504 peft_kwargs=peft_kwargs, 505 freeze=freeze, 506 **model_kwargs 507 ) 508 device = get_device(device) 509 model = get_unetr( 510 image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None), 511 device=device, flexible_load_checkpoint=not strict_decoder_loading, 512 ) 513 514 optimizer, scheduler = _get_optimizer_and_scheduler( 515 model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs 516 ) 517 trainer = torch_em.trainer.DefaultTrainer( 518 name=name, 519 model=model, 520 train_loader=train_loader, 521 val_loader=val_loader, 522 device=device, 523 mixed_precision=True, 524 log_image_interval=50, 525 compile_model=False, 526 save_root=save_root, 527 loss=loss, 528 metric=loss if metric is None else metric, 529 optimizer=optimizer, 530 lr_scheduler=scheduler, 531 early_stopping=early_stopping, 532 ) 533 534 trainer_fit_params = 
_get_trainer_fit_params( 535 n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training 536 ) 537 trainer.fit(**trainer_fit_params) 538 539 t_run = time.time() - t_start 540 hours = int(t_run // 3600) 541 minutes = int(t_run // 60) 542 seconds = int(round(t_run % 60, 0)) 543 print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)") 544 545 546def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels): 547 if isinstance(raw_paths, (str, os.PathLike)): 548 path = raw_paths 549 else: 550 path = raw_paths[0] 551 assert isinstance(path, (str, os.PathLike, np.ndarray, torch.Tensor)) 552 553 # Check the underlying data dimensionality. 554 if raw_key is None and isinstance(path, (np.ndarray, torch.Tensor)): 555 ndim = path.ndim 556 elif raw_key is None: # If no key is given and this is not a tensor/array then we assume it's an image file. 557 ndim = imageio.imread(path).ndim 558 else: # Otherwise we try to open the file from key. 559 try: # First try to open it with elf. 560 with open_file(path, "r") as f: 561 ndim = f[raw_key].ndim 562 except ValueError: # This may fail for images in a folder with different sizes. 563 # In that case we read one of the images. 564 image_path = glob(os.path.join(path, raw_key))[0] 565 ndim = imageio.imread(image_path).ndim 566 567 if not isinstance(patch_shape, tuple): 568 patch_shape = tuple(patch_shape) 569 570 if ndim == 2: 571 assert len(patch_shape) == 2 572 return patch_shape 573 elif ndim == 3 and len(patch_shape) == 2 and not with_channels: 574 return (1,) + patch_shape 575 elif ndim == 4 and len(patch_shape) == 2 and with_channels: 576 return (1,) + patch_shape 577 else: 578 return patch_shape 579 580 581def _determine_dataset_and_channels(raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs): 582 # By default, let the 'default_segmentation_dataset' heuristic decide for itself. 
583 is_seg_dataset = kwargs.pop("is_seg_dataset", None) 584 585 # Check if the input data is in numpy or torch. In this case determine we use 586 # the image collection dataset heuristic (torch_em will figure it out), 587 # and we determine the number of channels. 588 if isinstance(raw_paths, list) and isinstance(raw_paths[0], (np.ndarray, torch.Tensor)): 589 is_seg_dataset = False 590 if with_channels is None: 591 with_channels = raw_paths[0].ndim == 3 and raw_paths.shape[-1] == 3 592 return is_seg_dataset, with_channels 593 594 # Check if the raw inputs are RGB or not. If yes, use 'ImageCollectionDataset'. 595 # Get valid raw paths to make checks possible. 596 if raw_key and "*" in raw_key: # Use the wildcard pattern to find the filepath to only one image. 597 rpath = glob(os.path.join(raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key))[0] 598 else: # Otherwise, either 'raw_key' is None or container format, supported by 'elf', then we load 1 filepath. 599 rpath = raw_paths if isinstance(raw_paths, str) else raw_paths[0] 600 601 # Load one of the raw inputs to validate whether it is RGB or not. 602 test_raw_inputs = load_data(path=rpath, key=raw_key if raw_key and "*" not in raw_key else None) 603 if test_raw_inputs.ndim == 3: 604 if test_raw_inputs.shape[-1] == 3: # i.e. if it is an RGB image and has channels last. 605 is_seg_dataset = False # we use 'ImageCollectionDataset' in this case. 606 # We need to provide a list of inputs to 'ImageCollectionDataset'. 607 raw_paths = [raw_paths] if isinstance(raw_paths, str) else raw_paths 608 label_paths = [label_paths] if isinstance(label_paths, str) else label_paths 609 610 # This is not relevant for 'ImageCollectionDataset'. Hence, we set 'with_channels' to 'False'. 611 with_channels = False if with_channels is None else with_channels 612 613 elif test_raw_inputs.shape[0] == 3: # i.e. if it is a RGB image and has 3 channels first. 614 # This is relevant for 'SegmentationDataset'. 
If not provided by the user, we set this to 'True'. 615 with_channels = True if with_channels is None else with_channels 616 617 # Set 'with_channels' to 'False', i.e. the default behavior of 'default_segmentation_dataset' 618 # Otherwise, let the user make the choice as priority, else set this to our suggested default. 619 with_channels = False if with_channels is None else with_channels 620 621 return is_seg_dataset, with_channels 622 623 624def default_sam_dataset( 625 raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath], 626 raw_key: Optional[str], 627 label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath], 628 label_key: Optional[str], 629 patch_shape: Tuple[int], 630 with_segmentation_decoder: bool, 631 with_channels: Optional[bool] = None, 632 train_instance_segmentation_only: bool = False, 633 sampler: Optional[Callable] = None, 634 raw_transform: Optional[Callable] = None, 635 n_samples: Optional[int] = None, 636 is_train: bool = True, 637 min_size: int = 25, 638 max_sampling_attempts: Optional[int] = None, 639 rois: Optional[Union[slice, Tuple[slice, ...]]] = None, 640 is_multi_tensor: bool = True, 641 **kwargs, 642) -> Dataset: 643 """Create a PyTorch Dataset for training a SAM model. 644 645 Args: 646 raw_paths: The path(s) to the image data used for training. 647 Can either be multiple 2D images or volumetric data. 648 The data can also be passed as a list of numpy arrays or torch tensors. 649 raw_key: The key for accessing the image data. Internal filepath for hdf5-like input 650 or a glob pattern for selecting multiple files. 651 Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors. 652 label_paths: The path(s) to the label data used for training. 653 Can either be multiple 2D images or volumetric data. 654 The data can also be passed as a list of numpy arrays or torch tensors. 655 label_key: The key for accessing the label data. 
Internal filepath for hdf5-like input 656 or a glob pattern for selecting multiple files. 657 Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors. 658 patch_shape: The shape for training patches. 659 with_segmentation_decoder: Whether to train with additional segmentation decoder. 660 with_channels: Whether the image data has channels. By default, it makes the decision based on inputs. 661 train_instance_segmentation_only: Set this argument to True in order to 662 pass the dataset to `train_instance_segmentation`. By default, set to 'False'. 663 sampler: A sampler to reject batches according to a given criterion. 664 raw_transform: Transformation applied to the image data. 665 If not given the data will be cast to 8bit. 666 n_samples: The number of samples for this dataset. 667 is_train: Whether this dataset is used for training or validation. By default, set to 'True'. 668 min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'. 669 max_sampling_attempts: Number of sampling attempts to make from a dataset. 670 rois: The region of interest(s) for the data. 671 is_multi_tensor: Whether the input data to data transforms is multiple tensors or not. 672 kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. 673 674 Returns: 675 The segmentation dataset. 676 """ 677 678 # Check if this dataset should be used for instance segmentation only training. 679 # If yes, we set return_instances to False, since the instance channel must not 680 # be passed for this training mode. 681 return_instances = True 682 if train_instance_segmentation_only: 683 if not with_segmentation_decoder: 684 raise ValueError( 685 "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True." 686 ) 687 return_instances = False 688 689 # If a sampler is not passed, then we set a MinInstanceSampler, which requires 3 distinct instances per sample. 
690 # This is necessary, because training for interactive segmentation does not work on 'empty' images. 691 # However, if we train only the automatic instance segmentation decoder, then this sampler is not required 692 # and we do not set a default sampler. 693 if sampler is None and not train_instance_segmentation_only: 694 sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size) 695 696 is_seg_dataset, with_channels = _determine_dataset_and_channels( 697 raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs 698 ) 699 # Since `is_seg_dataset` is selected and sorted above, let's remove it out of running kwargs, if available. 700 kwargs.pop("is_seg_dataset", None) 701 702 # Set the data transformations. 703 if raw_transform is None: 704 raw_transform = require_8bit 705 706 # Prepare the label transform. 707 if with_segmentation_decoder: 708 default_label_transform = torch_em.transform.label.PerObjectDistanceTransform( 709 distances=True, 710 boundary_distances=True, 711 directed_distances=False, 712 foreground=True, 713 instances=return_instances, 714 min_size=min_size, 715 ) 716 else: 717 default_label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size) 718 719 # Allow combining label transforms. 720 custom_label_transform = kwargs.pop("label_transform", None) 721 if custom_label_transform is None: 722 label_transform = default_label_transform 723 else: 724 label_transform = torch_em.transform.generic.Compose( 725 custom_label_transform, default_label_transform, is_multi_tensor=is_multi_tensor 726 ) 727 728 # Check the patch shape to add a singleton if required. 729 patch_shape = _update_patch_shape( 730 patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels, 731 ) 732 733 # Set a minimum number of samples per epoch. 
734 if n_samples is None: 735 loader = torch_em.default_segmentation_loader( 736 raw_paths=raw_paths, 737 raw_key=raw_key, 738 label_paths=label_paths, 739 label_key=label_key, 740 batch_size=1, 741 patch_shape=patch_shape, 742 with_channels=with_channels, 743 ndim=2, 744 is_seg_dataset=is_seg_dataset, 745 raw_transform=raw_transform, 746 rois=rois, 747 **kwargs 748 ) 749 n_samples = max(len(loader), 100 if is_train else 5) 750 751 dataset = torch_em.default_segmentation_dataset( 752 raw_paths=raw_paths, 753 raw_key=raw_key, 754 label_paths=label_paths, 755 label_key=label_key, 756 patch_shape=patch_shape, 757 raw_transform=raw_transform, 758 label_transform=label_transform, 759 with_channels=with_channels, 760 ndim=2, 761 sampler=sampler, 762 n_samples=n_samples, 763 is_seg_dataset=is_seg_dataset, 764 rois=rois, 765 **kwargs, 766 ) 767 768 if max_sampling_attempts is not None: 769 if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset): 770 for ds in dataset.datasets: 771 ds.max_sampling_attempts = max_sampling_attempts 772 else: 773 dataset.max_sampling_attempts = max_sampling_attempts 774 775 return dataset 776 777 778def default_sam_loader(**kwargs) -> DataLoader: 779 """Create a PyTorch DataLoader for training a SAM model. 780 781 Args: 782 kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader. 783 784 Returns: 785 The DataLoader. 786 """ 787 sam_ds_kwargs, extra_kwargs = split_kwargs(default_sam_dataset, **kwargs) 788 789 # There might be additional parameters supported by `torch_em.default_segmentation_dataset`, 790 # which the users can provide to get their desired segmentation dataset. 
791 extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **extra_kwargs) 792 ds_kwargs = {**sam_ds_kwargs, **extra_ds_kwargs} 793 794 ds = default_sam_dataset(**ds_kwargs) 795 return torch_em.segmentation.get_data_loader(ds, **loader_kwargs) 796 797 798CONFIGURATIONS = { 799 "Minimal": {"model_type": "vit_t", "n_objects_per_batch": 4, "n_sub_iteration": 4}, 800 "CPU": {"model_type": "vit_b", "n_objects_per_batch": 10}, 801 "gtx1080": {"model_type": "vit_t", "n_objects_per_batch": 5}, 802 "gtx3080": { 803 "model_type": "vit_b", "n_objects_per_batch": 5, 804 "peft_kwargs": {"attention_layers_to_update": [11], "peft_module": ClassicalSurgery} 805 }, 806 "rtx5000": {"model_type": "vit_b", "n_objects_per_batch": 10}, 807 "V100": {"model_type": "vit_b"}, 808 "A100": {"model_type": "vit_h"}, 809} 810"""Best training configurations for given hardware resources. 811""" 812 813 814def _find_best_configuration(): 815 if torch.cuda.is_available(): 816 817 # Check how much memory we have and select the best matching GPU 818 # for the available VRAM size. 819 _, vram = torch.cuda.mem_get_info() 820 vram = vram / 1e9 # in GB 821 822 # Maybe we can get more configurations in the future. 823 if vram > 80: # More than 80 GB: use the A100 configurations. 824 return "A100" 825 elif vram > 30: # More than 30 GB: use the V100 configurations. 826 return "V100" 827 elif vram > 14: # More than 14 GB: use the RTX5000 configurations. 828 return "rtx5000" 829 elif vram > 8: # More than 8 GB: use the GTX3080 configurations. 830 return "gtx3080" 831 else: # Otherwise: not enough memory to train on the GPU, use CPU instead. 
832 return "CPU" 833 else: 834 return "CPU" 835 836 837def train_sam_for_configuration( 838 name: str, 839 train_loader: DataLoader, 840 val_loader: DataLoader, 841 configuration: Optional[str] = None, 842 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 843 with_segmentation_decoder: bool = True, 844 train_instance_segmentation_only: bool = False, 845 model_type: Optional[str] = None, 846 **kwargs, 847) -> None: 848 """Run training for a SAM model with the configuration for a given hardware resource. 849 850 Selects the best training settings for the given configuration. 851 The available configurations are listed in `CONFIGURATIONS`. 852 853 Args: 854 name: The name of the model to be trained. The checkpoint and logs folder will have this name. 855 train_loader: The dataloader for training. 856 val_loader: The dataloader for validation. 857 configuration: The configuration (= name of hardware resource). 858 By default, it is automatically selected for the best VRAM combination. 859 checkpoint_path: Path to checkpoint for initializing the SAM model. 860 with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. 861 By default, trains with the additional instance segmentation decoder. 862 train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation 863 using the training implementation `train_instance_segmentation`. By default, `train_sam` is used. 864 model_type: Over-ride the default model type. 865 This can be used to use one of the micro_sam models as starting point 866 instead of a default sam model. 867 kwargs: Additional keyword parameters that will be passed to `train_sam`. 868 """ 869 if configuration is None: # Automatically choose based on available VRAM combination. 
870 configuration = _find_best_configuration() 871 872 if configuration in CONFIGURATIONS: 873 train_kwargs = CONFIGURATIONS[configuration] 874 else: 875 raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}") 876 877 if model_type is None: 878 model_type = train_kwargs.pop("model_type") 879 else: 880 expected_model_type = train_kwargs.pop("model_type") 881 if model_type[:5] != expected_model_type: 882 warnings.warn("You have specified a different model type.") 883 884 train_kwargs.update(**kwargs) 885 if train_instance_segmentation_only: 886 instance_seg_kwargs, extra_kwargs = split_kwargs(train_instance_segmentation, **train_kwargs) 887 model_kwargs, extra_kwargs = split_kwargs(get_sam_model, **extra_kwargs) 888 instance_seg_kwargs.update(**model_kwargs) 889 890 train_instance_segmentation( 891 name=name, 892 train_loader=train_loader, 893 val_loader=val_loader, 894 checkpoint_path=checkpoint_path, 895 model_type=model_type, 896 **instance_seg_kwargs, 897 ) 898 else: 899 train_sam( 900 name=name, 901 train_loader=train_loader, 902 val_loader=val_loader, 903 checkpoint_path=checkpoint_path, 904 with_segmentation_decoder=with_segmentation_decoder, 905 model_type=model_type, 906 **train_kwargs 907 ) 908 909 910def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader): 911 912 # Whether the model is stored in the current working directory or in another location. 913 if save_root is None: 914 save_root = os.getcwd() # Map this to current working directory, if not specified by the user. 915 916 # Get the 'best' model checkpoint ready for export. 917 best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt") 918 if not os.path.exists(best_checkpoint): 919 raise FileNotFoundError(f"The trained model not found at the expected location: '{best_checkpoint}'.") 920 921 # Export the model if an output path has been given. 
922 if output_path: 923 924 # If the filepath has a pytorch-specific ending, then we just export the checkpoint. 925 if os.path.splitext(output_path)[1] in (".pt", ".pth"): 926 export_custom_sam_model( 927 checkpoint_path=best_checkpoint, 928 model_type=model_type[:5], 929 save_path=output_path, 930 with_segmentation_decoder=with_segmentation_decoder, 931 ) 932 933 # Otherwise we export it as bioimage.io model. 934 else: 935 from micro_sam.bioimageio import export_sam_model 936 937 # Load image and corresponding labels from the val loader. 938 with torch.no_grad(): 939 image_data, label_data = next(iter(val_loader)) 940 image_data, label_data = image_data.numpy().squeeze(), label_data.numpy().squeeze() 941 942 # Select the first channel of the label image if we have a channel axis, i.e. contains the labels 943 if label_data.ndim == 3: 944 label_data = label_data[0] # Gets the channel with instances. 945 assert image_data.shape == label_data.shape 946 label_data = label_data.astype("uint32") 947 948 export_sam_model( 949 image=image_data, 950 label_image=label_data, 951 model_type=model_type[:5], 952 name=checkpoint_name, 953 output_path=output_path, 954 checkpoint_path=best_checkpoint, 955 ) 956 957 # The final path where the model has been stored. 958 final_path = output_path 959 960 else: # If no exports have been made, inform the user about the best checkpoint. 
961 final_path = best_checkpoint 962 963 return final_path 964 965 966def _parse_segmentation_decoder(segmentation_decoder): 967 if segmentation_decoder in ("None", "none"): 968 with_segmentation_decoder, train_instance_segmentation_only = False, False 969 elif segmentation_decoder == "instances": 970 with_segmentation_decoder, train_instance_segmentation_only = True, False 971 elif segmentation_decoder == "instances_only": 972 with_segmentation_decoder, train_instance_segmentation_only = True, True 973 else: 974 raise ValueError( 975 "The 'segmentation_decoder' argument currently supports the values:\n" 976 f"'instances', 'instances_only', or 'None'. You have passed {segmentation_decoder}." 977 ) 978 return with_segmentation_decoder, train_instance_segmentation_only 979 980 981def main(): 982 """@private""" 983 import argparse 984 985 available_models = list(get_model_names()) 986 available_models = ", ".join(available_models) 987 988 available_configurations = list(CONFIGURATIONS.keys()) 989 available_configurations = ", ".join(available_configurations) 990 991 parser = argparse.ArgumentParser(description="Finetune Segment Anything Models on custom data.") 992 993 # Images and labels for training. 994 parser.add_argument( 995 "--images", required=True, type=str, nargs="*", 996 help="Filepath to images or the directory where the image data is stored." 997 ) 998 parser.add_argument( 999 "--labels", required=True, type=str, nargs="*", 1000 help="Filepath to ground-truth labels or the directory where the label data is stored." 1001 ) 1002 parser.add_argument( 1003 "--image_key", type=str, default=None, 1004 help="The key for accessing image data, either a pattern / wildcard or with elf.io.open_file. " 1005 ) 1006 parser.add_argument( 1007 "--label_key", type=str, default=None, 1008 help="The key for accessing label data, either a pattern / wildcard or with elf.io.open_file. " 1009 ) 1010 1011 # Images and labels for validation. 1012 # NOTE: This isn't required, i.e. 
we create a val-split on-the-fly from the training data if not provided. 1013 # Users can choose to have their explicit validation set via this feature as well. 1014 parser.add_argument( 1015 "--val_images", type=str, nargs="*", 1016 help="Filepath to images for validation or the directory where the image data is stored." 1017 ) 1018 parser.add_argument( 1019 "--val_labels", type=str, nargs="*", 1020 help="Filepath to ground-truth labels for validation or the directory where the label data is stored." 1021 ) 1022 parser.add_argument( 1023 "--val_image_key", type=str, default=None, 1024 help="The key for accessing image data for validation, either a pattern / wildcard or with elf.io.open_file." 1025 ) 1026 parser.add_argument( 1027 "--val_label_key", type=str, default=None, 1028 help="The key for accessing label data for validation, either a pattern / wildcard or with elf.io.open_file." 1029 ) 1030 1031 # Other necessary stuff for training. 1032 parser.add_argument( 1033 "--configuration", type=str, default=_find_best_configuration(), 1034 help=f"The configuration for finetuning the Segment Anything Model, one of {available_configurations}." 1035 ) 1036 1037 def none_or_str(value): 1038 if value.lower() == 'none': 1039 return None 1040 return value 1041 1042 # This could be extended to train for semantic segmentation or other options. 1043 parser.add_argument( 1044 "--segmentation_decoder", type=none_or_str, default="instances", 1045 help="Whether to finetune Segment Anything Model with an additional segmentation decoder. " 1046 "The following options are possible:\n" 1047 "- 'instances' to train with an additional decoder for automatic instance segmentation. " 1048 " This option enables using the automatic instance segmentation (AIS) mode.\n" 1049 "- 'instances_only' to train only the instance segmentation decoder. 
" 1050 " In this case the parts of SAM that are used for interactive segmentation will not be trained.\n" 1051 "- 'None' to train without an additional segmentation decoder." 1052 " This options trains only the parts of the original SAM.\n" 1053 "By default the option 'instances' is used." 1054 ) 1055 1056 # Optional advanced settings a user can opt to change the values for. 1057 parser.add_argument( 1058 "-d", "--device", type=str, default=None, 1059 help="The device to use for finetuning. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). " 1060 "By default the most performant available device will be selected." 1061 ) 1062 parser.add_argument( 1063 "--patch_shape", type=int, nargs="*", default=(512, 512), 1064 help="The choice of patch shape for training Segment Anything Model. " 1065 "By default, a patch size of 512x512 is used." 1066 ) 1067 parser.add_argument( 1068 "-m", "--model_type", type=str, default=None, 1069 help=f"The Segment Anything Model that will be used for finetuning, one of {available_models}." 1070 ) 1071 parser.add_argument( 1072 "--checkpoint_path", type=str, default=None, 1073 help="Checkpoint from which the SAM model will be loaded for finetuning." 1074 ) 1075 parser.add_argument( 1076 "-s", "--save_root", type=str, default=None, 1077 help="The directory where the trained models and corresponding logs will be stored. " 1078 "By default, there are stored in your current working directory." 1079 ) 1080 parser.add_argument( 1081 "--trained_model_name", type=str, default="sam_model", 1082 help="The custom name of trained model sub-folder. Allows users to have several trained models " 1083 "under the same 'save_root'." 1084 ) 1085 parser.add_argument( 1086 "--output_path", type=str, default=None, 1087 help="The directory (eg. '/path/to/folder') or filepath (eg. '/path/to/model.pt') to export the trained model." 
1088 ) 1089 parser.add_argument( 1090 "--n_epochs", type=int, default=100, 1091 help="The total number of epochs to train the Segment Anything Model. By default, trains for 100 epochs." 1092 ) 1093 parser.add_argument( 1094 "--num_workers", type=int, default=1, help="The number of workers for processing data with dataloaders." 1095 ) 1096 parser.add_argument( 1097 "--batch_size", type=int, default=1, 1098 help="The choice of batch size for training the Segment Anything Model. By default the batch size is set to 1." 1099 ) 1100 parser.add_argument( 1101 "--preprocess", type=str, default=None, choices=("normalize_minmax", "normalize_percentile"), 1102 help="Whether to normalize the raw inputs. By default, does not perform any preprocessing of input images " 1103 "Otherwise, choose from either 'normalize_percentile' or 'normalize_minmax'." 1104 ) 1105 1106 args = parser.parse_args() 1107 1108 # 1. Get all necessary stuff for training. 1109 checkpoint_name = args.trained_model_name 1110 config = args.configuration 1111 model_type = args.model_type 1112 checkpoint_path = args.checkpoint_path 1113 batch_size = args.batch_size 1114 patch_shape = args.patch_shape 1115 epochs = args.n_epochs 1116 num_workers = args.num_workers 1117 device = args.device 1118 save_root = args.save_root 1119 output_path = args.output_path 1120 with_segmentation_decoder, train_instance_segmentation_only = _parse_segmentation_decoder(args.segmentation_decoder) 1121 1122 # Get image paths and corresponding keys. 1123 train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key 1124 val_images, val_gt, val_image_key, val_gt_key = args.val_images, args.val_labels, args.val_image_key, args.val_label_key # noqa 1125 1126 # 2. Prepare the dataloaders. 1127 1128 # If the user wants to preprocess the inputs, we allow the possibility to do so. 1129 _raw_transform = get_raw_transform(args.preprocess) 1130 1131 # Get the dataset with files for training. 
1132 dataset = default_sam_dataset( 1133 raw_paths=train_images, 1134 raw_key=train_image_key, 1135 label_paths=train_gt, 1136 label_key=train_gt_key, 1137 patch_shape=patch_shape, 1138 with_segmentation_decoder=with_segmentation_decoder, 1139 raw_transform=_raw_transform, 1140 train_instance_segmentation_only=train_instance_segmentation_only, 1141 ) 1142 1143 # If val images are not exclusively provided, we create a val split from the training data. 1144 if val_images is None: 1145 assert val_gt is None and val_image_key is None and val_gt_key is None 1146 # Use 10% of the dataset for validation - at least one image - for validation. 1147 n_val = max(1, int(0.1 * len(dataset))) 1148 train_dataset, val_dataset = random_split(dataset, lengths=[len(dataset) - n_val, n_val]) 1149 1150 else: # If val images provided, we create a new dataset for it. 1151 train_dataset = dataset 1152 val_dataset = default_sam_dataset( 1153 raw_paths=val_images, 1154 raw_key=val_image_key, 1155 label_paths=val_gt, 1156 label_key=val_gt_key, 1157 patch_shape=patch_shape, 1158 with_segmentation_decoder=with_segmentation_decoder, 1159 train_instance_segmentation_only=train_instance_segmentation_only, 1160 raw_transform=_raw_transform, 1161 ) 1162 1163 # Get the dataloaders from the datasets. 1164 train_loader = torch_em.get_data_loader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 1165 val_loader = torch_em.get_data_loader(val_dataset, batch_size=1, shuffle=True, num_workers=num_workers) 1166 1167 # 3. Train the Segment Anything Model. 1168 1169 # Get a valid model and other necessary parameters for training. 
1170 if model_type is not None and model_type not in available_models: 1171 raise ValueError(f"'{model_type}' is not a valid choice of model.") 1172 if config is not None and config not in available_configurations: 1173 raise ValueError(f"'{config}' is not a valid choice of configuration.") 1174 1175 if model_type is None: # If user does not specify the model, we use the default model corresponding to the config. 1176 model_type = CONFIGURATIONS[config]["model_type"] 1177 1178 train_sam_for_configuration( 1179 name=checkpoint_name, 1180 configuration=config, 1181 model_type=model_type, 1182 train_loader=train_loader, 1183 val_loader=val_loader, 1184 n_epochs=epochs, 1185 checkpoint_path=checkpoint_path, 1186 with_segmentation_decoder=with_segmentation_decoder, 1187 freeze=None, # TODO: Allow for PEFT. 1188 device=device, 1189 save_root=save_root, 1190 peft_kwargs=None, # TODO: Allow for PEFT. 1191 train_instance_segmentation_only=train_instance_segmentation_only, 1192 ) 1193 1194 # 4. Export the model, if desired by the user 1195 if train_instance_segmentation_only and output_path: 1196 trained_path = os.path.join("" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt") 1197 export_instance_segmentation_model(trained_path, output_path, model_type, checkpoint_path) 1198 final_path = output_path 1199 else: 1200 final_path = _export_helper( 1201 save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader, 1202 ) 1203 1204 print(f"Training has finished. The trained model is saved at {final_path}.")
def train_sam(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    n_objects_per_batch: Optional[int] = 25,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    n_sub_iteration: int = 8,
    save_root: Optional[Union[str, os.PathLike]] = None,
    mask_prob: float = 0.5,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    verify_n_labels_in_loader: Optional[int] = 50,
    box_distortion_factor: Optional[float] = 0.025,
    overwrite_training: bool = True,
    strict_decoder_loading: bool = True,
    **model_kwargs,
) -> None:
    """Run training for a SAM model.

    The function verifies the data loaders (optional), builds the trainable SAM model,
    optionally attaches a UNETR decoder for automatic instance segmentation, creates the
    optimizer / scheduler and the matching trainer, runs the training and finally prints
    the elapsed wall-clock time.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
            By default, the value is set to '10' epochs.
        n_objects_per_batch: The number of objects per batch used to compute
            the loss for interactive segmentation. If None, all objects will be used;
            if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
            By default, trains with the additional instance segmentation decoder.
        freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder
            and mask_decoder. By default nothing is frozen and the full model is updated.
        device: The device to use for training. By default, automatically chooses the best available device to train.
        lr: The learning rate. By default, set to '1e-5'.
        n_sub_iteration: The number of iterative prompts per training iteration.
            By default, the number of iterations is set to '8'.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        mask_prob: The probability for using a mask as input in a given training sub-iteration.
            By default, set to '0.5'.
        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
        verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders.
            By default, 50 batches of labels are verified from the dataloaders.
            If set to None, the verification of the dataloaders is skipped.
        box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks.
            By default, the distortion factor is set to '0.025'.
        overwrite_training: Whether to overwrite the trained model stored at the same location.
            By default, overwrites the trained model at each run.
            If set to 'False', it will avoid retraining the model if the previous run was completed.
        strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
            exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output
            channels if they were pre-trained for a different task. If set to False, decoders with a different
            output dimension can be loaded; the output channels will be re-initialized.
        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
    """
    with _filter_warnings(ignore_warnings):

        t_start = time.time()
        if verify_n_labels_in_loader is not None:
            _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader)
            _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader)

        device = get_device(device)
        # Get the trainable segment anything model.
        model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            freeze=freeze,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            **model_kwargs
        )

        # This class creates all the training data for a batch (inputs, prompts and labels).
        convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=box_distortion_factor)

        # Create the UNETR decoder (if train with it) and collect the parameters to optimize.
        if with_segmentation_decoder:

            # Get the UNETR. It is built on top of the image encoder of the SAM model.
            unetr = get_unetr(
                image_encoder=model.sam.image_encoder, decoder_state=state.get("decoder_state", None),
                device=device, flexible_load_checkpoint=not strict_decoder_loading,
            )

            # Optimize the SAM parameters plus the UNETR parameters that are not part of its encoder.
            # The "encoder*" parameters of the UNETR are skipped, they are already covered by the SAM parameters.
            model_params = list(model.parameters())
            model_params.extend(
                params for param_name, params in unetr.named_parameters() if not param_name.startswith("encoder")
            )
        else:
            model_params = model.parameters()

        optimizer, scheduler = _get_optimizer_and_scheduler(
            model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs
        )

        # The trainer which performs training and validation.
        if with_segmentation_decoder:
            # The same loss instance is used as training loss and as validation metric for the decoder.
            instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
            trainer = joint_trainers.JointSamTrainer(
                name=name,
                save_root=save_root,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=joint_trainers.JointSamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                unetr=unetr,
                instance_loss=instance_seg_loss,
                instance_metric=instance_seg_loss,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
            )
        else:
            trainer = trainers.SamTrainer(
                name=name,
                train_loader=train_loader,
                val_loader=val_loader,
                model=model,
                optimizer=optimizer,
                device=device,
                lr_scheduler=scheduler,
                logger=trainers.SamLogger,
                log_image_interval=100,
                mixed_precision=True,
                convert_inputs=convert_inputs,
                n_objects_per_batch=n_objects_per_batch,
                n_sub_iteration=n_sub_iteration,
                compile_model=False,
                early_stopping=early_stopping,
                mask_prob=mask_prob,
                save_root=save_root,
            )

        trainer_fit_params = _get_trainer_fit_params(
            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
        )
        trainer.fit(**trainer_fit_params)

        # Report the runtime as HH:MM:SS. The total is rounded to whole seconds first and then split
        # with divmod, so that the minutes stay within [0, 59] and the seconds can never be shown as 60.
        t_run = time.time() - t_start
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
Run training for a SAM model.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
- n_objects_per_batch: The number of objects per batch used to compute the loss for interactive segmentation. If None, all objects will be used; if given, objects will be randomly sub-sampled. By default, the number of objects per batch is '25'.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
- freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder. By default nothing is frozen and the full model is updated.
- device: The device to use for training. By default, automatically chooses the best available device to train.
- lr: The learning rate. By default, set to '1e-5'.
- n_sub_iteration: The number of iterative prompts per training iteration. By default, the number of iterations is set to '8'.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- mask_prob: The probability for using a mask as input in a given training sub-iteration. By default, set to '0.5'.
- n_iterations: The number of iterations to use for training. This will override `n_epochs` if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
- verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. By default, 50 batches of labels are verified from the dataloaders.
- box_distortion_factor: The factor for distorting the box annotations derived from the ground-truth masks. By default, the distortion factor is set to '0.025'.
- overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
- strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present, exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output channels if they were pre-trained for a different task. If set to False, decoders with a different output dimension can be loaded; the output channels will be re-initialized.
- model_kwargs: Additional keyword arguments for `micro_sam.util.get_sam_model`.
def export_instance_segmentation_model(
    trained_model_path: Union[str, os.PathLike],
    output_path: Union[str, os.PathLike],
    model_type: str,
    initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Export a model trained for instance segmentation with `train_instance_segmentation`.

    The exported checkpoint combines the weights of a SAM model (loaded via `get_sam_model`),
    whose image encoder weights are replaced by the trained encoder weights, with the trained
    decoder weights. It is stored as a dictionary with the keys 'model_state' and 'decoder_state'.

    The exported model will be compatible with the micro_sam functions, CLI and napari plugin.
    It should only be used for automatic segmentation and may not work well for interactive segmentation.

    Args:
        trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
        output_path: The path where the exported model will be saved.
        model_type: The model type.
        initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
    """
    checkpoint = torch.load(trained_model_path, weights_only=False, map_location="cpu")
    # Trainer checkpoints keep the weights under "model_state"; otherwise treat the file as a plain state dict.
    unetr_state = checkpoint.get("model_state", checkpoint)

    # Partition the trained weights in a single pass: everything that starts with "encoder"
    # belongs to the image encoder, all remaining entries form the instance segmentation decoder.
    encoder_state, decoder_state = OrderedDict(), OrderedDict()
    for key, weights in unetr_state.items():
        destination = encoder_state if key.startswith("encoder") else decoder_state
        destination[key] = weights

    # Load the original state of the model that was used as the basis of instance segmentation training.
    predictor = get_sam_model(
        model_type=model_type, checkpoint_path=initial_checkpoint_path, device="cpu",
    )
    model_state = OrderedDict(predictor.model.state_dict())

    # Overwrite the image encoder weights with the trained ones.
    # UNETR stores the image encoder as "encoder.*"; SAM uses "image_encoder.*".
    sam_prefix, unetr_prefix = "image_encoder.", "encoder."
    for sam_key in tuple(model_state):
        if not sam_key.startswith(sam_prefix):
            continue
        model_state[sam_key] = encoder_state[unetr_prefix + sam_key[len(sam_prefix):]]

    torch.save({"model_state": model_state, "decoder_state": decoder_state}, output_path)
Export a model trained for instance segmentation with train_instance_segmentation.
The exported model will be compatible with the micro_sam functions, CLI and napari plugin. It should only be used for automatic segmentation and may not work well for interactive segmentation.
Arguments:
- trained_model_path: The path to the checkpoint of the model trained for instance segmentation.
- output_path: The path where the exported model will be saved.
- model_type: The model type.
- initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional).
def train_instance_segmentation(
    name: str,
    model_type: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    n_epochs: int = 100,
    early_stopping: Optional[int] = 10,
    loss: torch.nn.Module = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True),
    metric: Optional[torch.nn.Module] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    freeze: Optional[List[str]] = None,
    device: Optional[Union[str, torch.device]] = None,
    lr: float = 1e-5,
    save_root: Optional[Union[str, os.PathLike]] = None,
    n_iterations: Optional[int] = None,
    scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau,
    scheduler_kwargs: Optional[Dict[str, Any]] = None,
    save_every_kth_epoch: Optional[int] = None,
    pbar_signals: Optional[QObject] = None,
    optimizer_class: Optional[Optimizer] = torch.optim.AdamW,
    peft_kwargs: Optional[Dict] = None,
    ignore_warnings: bool = True,
    overwrite_training: bool = True,
    strict_decoder_loading: bool = True,
    **model_kwargs,
) -> None:
    """Train a UNETR for instance segmentation using the SAM encoder as backbone.

    This setting corresponds to training a SAM model with an instance segmentation decoder,
    without training the model parts for interactive segmentation,
    i.e. without training the prompt encoder and mask decoder.

    The checkpoint of the trained model, which will be saved in 'checkpoints/<name>',
    will not be compatible with the micro_sam functionality.
    You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it
    in a format that is compatible with micro_sam functionality.
    Note that the exported model should only be used for automatic segmentation via AIS.

    After training, the elapsed wall-clock time is printed in seconds and as HH:MM:SS.

    Args:
        name: The name of the model to be trained. The checkpoint and logs will have this name.
        model_type: The type of the SAM model.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        n_epochs: The number of epochs to train for.
        early_stopping: Enable early stopping after this number of epochs without improvement.
            By default, the value is set to '10' epochs.
        loss: The loss function to train the instance segmentation model.
            By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'.
        metric: The metric for the instance segmentation training.
            By default the loss function is used as the metric.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen.
            By default nothing is frozen and the full model is updated.
        device: The device to use for training. By default, automatically chooses the best available device to train.
        lr: The learning rate. By default, set to '1e-5'.
        save_root: Optional root directory for saving the checkpoints and logs.
            If not given the current working directory is used.
        n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
        scheduler_class: The learning rate scheduler to update the learning rate.
            By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
        scheduler_kwargs: The learning rate scheduler parameters.
            If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
        save_every_kth_epoch: Save checkpoints after every kth epoch separately.
        pbar_signals: Controls for napari progress bar.
        optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
        overwrite_training: Whether to overwrite the trained model stored at the same location.
            By default, overwrites the trained model at each run.
            If set to 'False', it will avoid retraining the model if the previous run was completed.
        strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present,
            exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output
            channels if they were pre-trained for a different task. If set to False, decoders with a different
            output dimension can be loaded; the output channels will be re-initialized.
        model_kwargs: Additional keyword arguments for the `micro_sam.util.get_sam_model`.
    """

    with _filter_warnings(ignore_warnings):
        t_start = time.time()

        # Build the SAM model first; only its image encoder is used as the backbone below.
        sam_model, state = get_trainable_sam_model(
            model_type=model_type,
            device=device,
            checkpoint_path=checkpoint_path,
            return_state=True,
            peft_kwargs=peft_kwargs,
            freeze=freeze,
            **model_kwargs
        )
        device = get_device(device)
        # A decoder state stored in the checkpoint (if any) is used to initialize the decoder.
        model = get_unetr(
            image_encoder=sam_model.sam.image_encoder, decoder_state=state.get("decoder_state", None),
            device=device, flexible_load_checkpoint=not strict_decoder_loading,
        )

        optimizer, scheduler = _get_optimizer_and_scheduler(
            model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs
        )
        trainer = torch_em.trainer.DefaultTrainer(
            name=name,
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            device=device,
            mixed_precision=True,
            log_image_interval=50,
            compile_model=False,
            save_root=save_root,
            loss=loss,
            # Fall back to the loss as validation metric if no dedicated metric is given.
            metric=loss if metric is None else metric,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            early_stopping=early_stopping,
        )

        trainer_fit_params = _get_trainer_fit_params(
            n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training
        )
        trainer.fit(**trainer_fit_params)

        t_run = time.time() - t_start
        # Split the rounded total into H:M:S so that minutes and seconds always stay within [0, 59].
        # (Using 't_run // 60' directly for the minutes would report e.g. 02:120:00 for a two hour run.)
        hours, remainder = divmod(int(round(t_run)), 3600)
        minutes, seconds = divmod(remainder, 60)
        print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)")
Train a UNETR for instance segmentation using the SAM encoder as backbone.
This setting corresponds to training a SAM model with an instance segmentation decoder, without training the model parts for interactive segmentation, i.e. without training the prompt encoder and mask decoder.
The checkpoint of the trained model, which will be saved in 'checkpoints/<name>', will not be compatible with the micro_sam functionality. You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it in a format that is compatible with micro_sam functionality.
Note that the exported model should only be used for automatic segmentation via AIS.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs will have this name.
- model_type: The type of the SAM model.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- n_epochs: The number of epochs to train for.
- early_stopping: Enable early stopping after this number of epochs without improvement. By default, the value is set to '10' epochs.
- loss: The loss function to train the instance segmentation model. By default, the value is set to 'torch_em.loss.DiceBasedDistanceLoss'
- metric: The metric for the instance segmentation training. By default the loss function is used as the metric.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. By default nothing is frozen and the full model is updated.
- device: The device to use for training. By default, automatically chooses the best available device to train.
- lr: The learning rate. By default, set to '1e-5'.
- save_root: Optional root directory for saving the checkpoints and logs. If not given the current working directory is used.
- n_iterations: The number of iterations to use for training. This will over-ride `n_epochs` if given.
- scheduler_class: The learning rate scheduler to update the learning rate. By default, `torch.optim.lr_scheduler.ReduceLROnPlateau` is used.
- scheduler_kwargs: The learning rate scheduler parameters. If passed 'None', the chosen default parameters are used in `ReduceLROnPlateau`.
- save_every_kth_epoch: Save checkpoints after every kth epoch separately.
- pbar_signals: Controls for napari progress bar.
- optimizer_class: The optimizer class. By default, `torch.optim.AdamW` is used.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- ignore_warnings: Whether to ignore raised warnings. By default, set to 'True'.
- overwrite_training: Whether to overwrite the trained model stored at the same location. By default, overwrites the trained model at each run. If set to 'False', it will avoid retraining the model if the previous run was completed.
- strict_decoder_loading: Whether to require that the pre-trained decoder in the checkpoint, if present, exactly matches the instance segmentation decoder. Decoders may have a mismatch in the output channels if they were pre-trained for a different task. If set to False, decoders with a different output dimension can be loaded; the output channels will be re-initialized.
- model_kwargs: Additional keyword arguments for `micro_sam.util.get_sam_model`.
def default_sam_dataset(
    raw_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
    raw_key: Optional[str],
    label_paths: Union[List[Union[np.ndarray, torch.Tensor]], List[FilePath], FilePath],
    label_key: Optional[str],
    patch_shape: Tuple[int],
    with_segmentation_decoder: bool,
    with_channels: Optional[bool] = None,
    train_instance_segmentation_only: bool = False,
    sampler: Optional[Callable] = None,
    raw_transform: Optional[Callable] = None,
    n_samples: Optional[int] = None,
    is_train: bool = True,
    min_size: int = 25,
    max_sampling_attempts: Optional[int] = None,
    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
    is_multi_tensor: bool = True,
    **kwargs,
) -> Dataset:
    """Create a PyTorch Dataset for training a SAM model.

    Args:
        raw_paths: The path(s) to the image data used for training.
            Can either be multiple 2D images or volumetric data.
            The data can also be passed as a list of numpy arrays or torch tensors.
        raw_key: The key for accessing the image data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
        label_paths: The path(s) to the label data used for training.
            Can either be multiple 2D images or volumetric data.
            The data can also be passed as a list of numpy arrays or torch tensors.
        label_key: The key for accessing the label data. Internal filepath for hdf5-like input
            or a glob pattern for selecting multiple files.
            Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
        patch_shape: The shape for training patches.
        with_segmentation_decoder: Whether to train with additional segmentation decoder.
        with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
        train_instance_segmentation_only: Set this argument to True in order to
            pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
        sampler: A sampler to reject batches according to a given criterion.
            If not given, a `MinInstanceSampler` is used, unless `train_instance_segmentation_only` is True.
        raw_transform: Transformation applied to the image data.
            If not given the data will be cast to 8bit.
        n_samples: The number of samples for this dataset. If not given, it is derived from the data,
            with a lower bound of 100 for training and 5 for validation.
        is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
        min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
        max_sampling_attempts: Number of sampling attempts to make from a dataset.
        rois: The region of interest(s) for the data.
        is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
            A 'label_transform' given here is applied before the default label transform.

    Returns:
        The segmentation dataset.

    Raises:
        ValueError: If `train_instance_segmentation_only` is True but `with_segmentation_decoder` is False.
    """
    # Training only the instance segmentation decoder requires the decoder targets.
    if train_instance_segmentation_only and not with_segmentation_decoder:
        raise ValueError(
            "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True."
        )
    # The instance channel must not be part of the targets in instance-segmentation-only mode.
    return_instances = not train_instance_segmentation_only

    # Default sampler: a MinInstanceSampler constructed with a minimum count of 2 and the given 'min_size'.
    # It is needed because training for interactive segmentation does not work on 'empty' images.
    # Training only the automatic instance segmentation decoder does not need it, so no default is set there.
    if sampler is None and not train_instance_segmentation_only:
        sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size)

    # The helper sees the full kwargs; 'is_seg_dataset' is resolved by it and hence dropped afterwards.
    is_seg_dataset, with_channels = _determine_dataset_and_channels(
        raw_paths, raw_key, label_paths, label_key, with_channels, **kwargs
    )
    kwargs.pop("is_seg_dataset", None)

    # Cast to 8bit unless the caller provides a raw transform.
    raw_transform = require_8bit if raw_transform is None else raw_transform

    # Default label transform: distance-based decoder targets or plain size filtering.
    if with_segmentation_decoder:
        label_transform = torch_em.transform.label.PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            instances=return_instances,
            min_size=min_size,
        )
    else:
        label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size)

    # A user-provided label transform is chained in front of the default one.
    custom_label_transform = kwargs.pop("label_transform", None)
    if custom_label_transform is not None:
        label_transform = torch_em.transform.generic.Compose(
            custom_label_transform, label_transform, is_multi_tensor=is_multi_tensor
        )

    # Check the patch shape to add a singleton if required.
    patch_shape = _update_patch_shape(
        patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels,
    )

    # Arguments shared by the temporary loader (used for counting) and the final dataset.
    shared_kwargs = dict(
        raw_paths=raw_paths,
        raw_key=raw_key,
        label_paths=label_paths,
        label_key=label_key,
        patch_shape=patch_shape,
        with_channels=with_channels,
        ndim=2,
        is_seg_dataset=is_seg_dataset,
        raw_transform=raw_transform,
        rois=rois,
    )

    # Without an explicit sample count, use the length of a batch-size-1 loader,
    # but enforce a minimum number of samples per epoch.
    if n_samples is None:
        counting_loader = torch_em.default_segmentation_loader(batch_size=1, **shared_kwargs, **kwargs)
        lower_bound = 100 if is_train else 5
        n_samples = max(len(counting_loader), lower_bound)

    dataset = torch_em.default_segmentation_dataset(
        label_transform=label_transform,
        sampler=sampler,
        n_samples=n_samples,
        **shared_kwargs,
        **kwargs,
    )

    # Propagate the sampling attempt limit to the dataset, or to each member of a concatenated dataset.
    if max_sampling_attempts is not None:
        is_concat = isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset)
        for ds in (dataset.datasets if is_concat else [dataset]):
            ds.max_sampling_attempts = max_sampling_attempts

    return dataset
Create a PyTorch Dataset for training a SAM model.
Arguments:
- raw_paths: The path(s) to the image data used for training. Can either be multiple 2D images or volumetric data. The data can also be passed as a list of numpy arrays or torch tensors.
- raw_key: The key for accessing the image data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files. Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
- label_paths: The path(s) to the label data used for training. Can either be multiple 2D images or volumetric data. The data can also be passed as a list of numpy arrays or torch tensors.
- label_key: The key for accessing the label data. Internal filepath for hdf5-like input or a glob pattern for selecting multiple files. Set to None when passing a list of file paths to regular images or numpy arrays / torch tensors.
- patch_shape: The shape for training patches.
- with_segmentation_decoder: Whether to train with additional segmentation decoder.
- with_channels: Whether the image data has channels. By default, it makes the decision based on inputs.
- train_instance_segmentation_only: Set this argument to True in order to pass the dataset to `train_instance_segmentation`. By default, set to 'False'.
- sampler: A sampler to reject batches according to a given criterion.
- raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit.
- n_samples: The number of samples for this dataset.
- is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
- min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
- max_sampling_attempts: Number of sampling attempts to make from a dataset.
- rois: The region of interest(s) for the data.
- is_multi_tensor: Whether the input data to data transforms is multiple tensors or not.
- kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Returns:
The segmentation dataset.
def default_sam_loader(**kwargs) -> DataLoader:
    """Create a PyTorch DataLoader for training a SAM model.

    The keyword arguments are distributed in two steps: arguments of `default_sam_dataset`
    and of `torch_em.default_segmentation_dataset` are used to build the dataset,
    everything that is left over is passed on to the data loader.

    Args:
        kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.

    Returns:
        The DataLoader.
    """
    # Step 1: pick out the arguments that `default_sam_dataset` accepts itself.
    sam_ds_kwargs, remaining_kwargs = split_kwargs(default_sam_dataset, **kwargs)

    # Step 2: the remainder may still contain parameters supported by `torch_em.default_segmentation_dataset`,
    # which users can provide to customize the segmentation dataset. What is left goes to the loader.
    extra_ds_kwargs, loader_kwargs = split_kwargs(torch_em.default_segmentation_dataset, **remaining_kwargs)

    dataset_kwargs = dict(sam_ds_kwargs)
    dataset_kwargs.update(extra_ds_kwargs)

    dataset = default_sam_dataset(**dataset_kwargs)
    return torch_em.segmentation.get_data_loader(dataset, **loader_kwargs)
Create a PyTorch DataLoader for training a SAM model.
Arguments:
- kwargs: Keyword arguments for `micro_sam.training.default_sam_dataset` or for the PyTorch DataLoader.
Returns:
The DataLoader.
Best training configurations for given hardware resources.
def train_sam_for_configuration(
    name: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    configuration: Optional[str] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    with_segmentation_decoder: bool = True,
    train_instance_segmentation_only: bool = False,
    model_type: Optional[str] = None,
    **kwargs,
) -> None:
    """Run training for a SAM model with the configuration for a given hardware resource.

    Selects the best training settings for the given configuration.
    The available configurations are listed in `CONFIGURATIONS`.

    Args:
        name: The name of the model to be trained. The checkpoint and logs folder will have this name.
        train_loader: The dataloader for training.
        val_loader: The dataloader for validation.
        configuration: The configuration (= name of hardware resource).
            By default, it is automatically selected for the best VRAM combination.
        checkpoint_path: Path to checkpoint for initializing the SAM model.
        with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation.
            By default, trains with the additional instance segmentation decoder.
        train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation
            using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
        model_type: Over-ride the default model type.
            This can be used to use one of the micro_sam models as starting point
            instead of a default sam model.
        kwargs: Additional keyword parameters that will be passed to `train_sam`. They take precedence
            over the values of the configuration. If `train_instance_segmentation_only` is True,
            only the parameters accepted by `train_instance_segmentation` or `get_sam_model` are
            forwarded; all others are discarded.

    Raises:
        ValueError: If the configuration is not one of the entries in `CONFIGURATIONS`.
    """
    if configuration is None:  # Automatically choose based on available VRAM combination.
        configuration = _find_best_configuration()

    if configuration not in CONFIGURATIONS:
        raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")

    # Work on a copy: the settings are modified below (pop / update) and the module-level
    # CONFIGURATIONS must stay intact. Otherwise a second call with the same configuration
    # would fail with a KeyError for 'model_type' and inherit the overrides of the previous call.
    train_kwargs = dict(CONFIGURATIONS[configuration])

    # The model type is passed explicitly to the training function, so it is always removed from the settings.
    expected_model_type = train_kwargs.pop("model_type")
    if model_type is None:
        model_type = expected_model_type
    elif model_type[:5] != expected_model_type:
        # Only the first five characters (e.g. 'vit_b') are compared with the configured type.
        warnings.warn("You have specified a different model type.")

    # User-provided arguments take precedence over the configuration defaults.
    train_kwargs.update(kwargs)

    if train_instance_segmentation_only:
        # Keep only what `train_instance_segmentation` and `get_sam_model` accept; the rest is discarded.
        instance_seg_kwargs, extra_kwargs = split_kwargs(train_instance_segmentation, **train_kwargs)
        model_kwargs, extra_kwargs = split_kwargs(get_sam_model, **extra_kwargs)
        instance_seg_kwargs.update(**model_kwargs)

        train_instance_segmentation(
            name=name,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_path=checkpoint_path,
            model_type=model_type,
            **instance_seg_kwargs,
        )
    else:
        train_sam(
            name=name,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_path=checkpoint_path,
            with_segmentation_decoder=with_segmentation_decoder,
            model_type=model_type,
            **train_kwargs
        )
Run training for a SAM model with the configuration for a given hardware resource.
Selects the best training settings for the given configuration.
The available configurations are listed in CONFIGURATIONS.
Arguments:
- name: The name of the model to be trained. The checkpoint and logs folder will have this name.
- train_loader: The dataloader for training.
- val_loader: The dataloader for validation.
- configuration: The configuration (= name of hardware resource). By default, it is automatically selected for the best VRAM combination.
- checkpoint_path: Path to checkpoint for initializing the SAM model.
- with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. By default, trains with the additional instance segmentation decoder.
- train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation using the training implementation `train_instance_segmentation`. By default, `train_sam` is used.
- model_type: Over-ride the default model type. This can be used to use one of the micro_sam models as starting point instead of a default sam model.
- kwargs: Additional keyword parameters that will be passed to `train_sam`.