micro_sam.training.joint_sam_trainer
import os
import time
import numpy as np
from collections import OrderedDict

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from .sam_trainer import SamTrainer

from torch_em.trainer.logger_base import TorchEmLogger
from torch_em.trainer.tensorboard_logger import normalize_im


class JointSamTrainer(SamTrainer):
    """Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Args:
        unetr: The UNet-style model with vision transformer as the image encoder.
            Required to perform automatic instance segmentation.
        instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
        instance_metric: The metric to compare the predictions and the targets.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
    """

    def __init__(
        self, unetr: torch.nn.Module, instance_loss: torch.nn.Module, instance_metric: torch.nn.Module, **kwargs
    ):
        super().__init__(**kwargs)
        self.unetr = unetr
        self.instance_loss = instance_loss
        self.instance_metric = instance_metric

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        # Store only the decoder part of the UNETR: the image encoder is shared
        # with SAM and is already saved as part of the model state.
        current_unetr_state = self.unetr.state_dict()
        decoder_state = []
        for k, v in current_unetr_state.items():
            if not k.startswith("encoder"):
                decoder_state.append((k, v))
        decoder_state = OrderedDict(decoder_state)

        super().save_checkpoint(
            name, current_metric=current_metric, best_metric=best_metric, decoder_state=decoder_state, **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        save_dict = super().load_checkpoint(checkpoint)

        # Get the image encoder parameters from SAM.
        sam_state = save_dict["model_state"]
        encoder_state = []
        prune_prefix = "sam.image_"
        for k, v in sam_state.items():
            if k.startswith(prune_prefix):
                encoder_state.append((k[len(prune_prefix):], v))
        encoder_state = OrderedDict(encoder_state)

        # Get the decoder parameters from the UNETR.
        decoder_state = save_dict["decoder_state"]

        # Merge the two to get the parameters for the UNETR.
        unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))

        self.unetr.load_state_dict(unetr_state)
        self.unetr.to(self.device)
        return save_dict

    def _instance_iteration(self, x, y, metric_for_val=False):
        """Perform the segmentation of distance maps and
        compute the loss (and metric) between the prediction and target.
        """
        outputs = self.unetr(x.to(self.device))
        loss = self.instance_loss(outputs, y.to(self.device))
        if metric_for_val:
            metric = self.instance_metric(outputs, y.to(self.device))
            return loss, metric
        else:
            return loss

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            # Channel 0 holds the instance labels, the remaining channels the
            # targets for the convolutional decoder.
            labels_instances = y[:, 0, ...].unsqueeze(1)
            labels_for_unetr = y[:, 1:, ...]

            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                # 1. Train for the interactive segmentation.
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)

            backprop(loss)

            self.optimizer.zero_grad()

            with forward_context():
                # 2. Train for the automatic instance segmentation.
                unetr_loss = self._instance_iteration(x, labels_for_unetr)

            backprop(unetr_loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(
                    self._iteration, loss, lr, x, labels_instances, samples,
                    mask_loss, iou_regression_loss, model_iou, unetr_loss
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                labels_instances = y[:, 0, ...].unsqueeze(1)
                labels_for_unetr = y[:, 1:, ...]

                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    # 1. Validate for the interactive segmentation.
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)

                with forward_context():
                    # 2. Validate for the automatic instance segmentation.
                    unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metric_for_val=True)

                loss_val += loss.item()
                # Combine the interactive metric with the down-weighted instance segmentation metric.
                metric_val += metric.item() + (unetr_metric.item() / 3)
                model_iou_val += model_iou.item()
                val_iteration += 1

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        model_iou_val /= len(self.val_loader)

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
                mask_loss, iou_regression_loss, model_iou_val, unetr_loss
            )

        return metric_val


class JointSamLogger(TorchEmLogger):
    """@private"""
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else \
            os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        # For volumetric inputs, log the central slice.
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())

        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(
        self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(
        self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
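The checkpoint handling above splits the UNETR weights: `save_checkpoint` stores only the decoder parameters (all keys not prefixed with "encoder") under `decoder_state`, since the image encoder is shared with SAM and already saved in `model_state`; `load_checkpoint` reassembles the full UNETR state by stripping the "sam.image_" prefix from the SAM encoder keys and appending the decoder entries. Below is a minimal sketch of reading such a checkpoint outside the trainer; the checkpoint path is hypothetical, as the actual location depends on `save_root` and the trainer `name`.

import torch
from collections import OrderedDict

# Hypothetical path; torch_em writes checkpoints to <save_root>/checkpoints/<name>/.
ckpt = torch.load("checkpoints/joint_sam/best.pt", map_location="cpu")

# SAM stores the encoder weights with a "sam.image_encoder." prefix, while the
# UNETR expects "encoder."; stripping "sam.image_" performs that renaming.
prune_prefix = "sam.image_"
encoder_state = OrderedDict(
    (k[len(prune_prefix):], v)
    for k, v in ckpt["model_state"].items()
    if k.startswith(prune_prefix)
)

# The decoder weights are stored separately by save_checkpoint.
decoder_state = ckpt["decoder_state"]

# The merged dict matches the layout of the UNETR's state dict.
unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))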
Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.
This class inherits from `SamTrainer`. Check out
https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
for details on its implementation.
Arguments:
- unetr: The UNet-style model with vision transformer as the image encoder. Required to perform automatic instance segmentation.
- instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
- instance_metric: The metric to compare the predictions and the targets.
- kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
JointSamTrainer(unetr: torch.nn.modules.module.Module, instance_loss: torch.nn.modules.module.Module, instance_metric: torch.nn.modules.module.Module, **kwargs)
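For orientation, a minimal construction sketch, assuming the SAM model, the UNETR, the loaders, and the input conversion function have been prepared elsewhere (e.g. with the `micro_sam.training` helpers). Only `unetr`, `instance_loss`, and `instance_metric` are consumed by this class; all other keyword arguments are forwarded to `SamTrainer` and `DefaultTrainer` (see the inherited members below), and the values shown are illustrative, not prescriptive.

import torch
from micro_sam.training.joint_sam_trainer import JointSamTrainer

def build_joint_trainer(model, unetr, train_loader, val_loader, convert_inputs, instance_loss, instance_metric):
    """Hypothetical helper: wire up a JointSamTrainer from prepared components.

    The loaders are expected to yield labels with the instance segmentation in
    channel 0 and the decoder targets in the remaining channels (see
    _train_epoch_impl above).
    """
    return JointSamTrainer(
        name="joint_sam",
        model=model,                    # the trainable SAM wrapper
        unetr=unetr,                    # UNETR sharing SAM's image encoder
        instance_loss=instance_loss,
        instance_metric=instance_metric,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=torch.optim.AdamW(model.parameters(), lr=1e-5),
        device=torch.device("cuda"),
        convert_inputs=convert_inputs,  # converts loader batches to SAM inputs
        n_objects_per_batch=25,
        n_sub_iteration=8,
        mixed_precision=True,
        compile_model=False,
    )

# Training then runs via the inherited fit method, e.g.:
# build_joint_trainer(...).fit(iterations=10000)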
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
- convert_inputs
- mse_loss
- n_objects_per_batch
- n_sub_iteration
- prompt_generator
- mask_prob
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit