micro_sam.training.joint_sam_trainer
import os
import time
import numpy as np
from collections import OrderedDict

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from .sam_trainer import SamTrainer

from torch_em.trainer.logger_base import TorchEmLogger
from torch_em.trainer.tensorboard_logger import normalize_im


class JointSamTrainer(SamTrainer):
    """Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Args:
        unetr: The UNet-style model with vision transformer as the image encoder.
            Required to perform automatic instance segmentation.
        instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
        instance_metric: The metric to compare the predictions and the targets.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
    """
    def __init__(
        self,
        unetr: torch.nn.Module,
        instance_loss: torch.nn.Module,
        instance_metric: torch.nn.Module,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.unetr = unetr
        self.instance_loss = instance_loss
        self.instance_metric = instance_metric

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        # Store only the decoder parameters of the UNETR: the encoder is shared
        # with SAM and is already saved as part of the model state.
        current_unetr_state = self.unetr.state_dict()
        decoder_state = []
        for k, v in current_unetr_state.items():
            if not k.startswith("encoder"):
                decoder_state.append((k, v))
        decoder_state = OrderedDict(decoder_state)

        super().save_checkpoint(
            name, current_metric=current_metric, best_metric=best_metric, decoder_state=decoder_state, **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        save_dict = super().load_checkpoint(checkpoint)

        # Get the image encoder parameters from SAM.
        sam_state = save_dict["model_state"]
        encoder_state = []
        prune_prefix = "sam.image_"
        for k, v in sam_state.items():
            if k.startswith(prune_prefix):
                encoder_state.append((k[len(prune_prefix):], v))
        encoder_state = OrderedDict(encoder_state)

        # Get the decoder parameters from the UNETR.
        decoder_state = save_dict["decoder_state"]

        # Merge the two to get the parameters for the UNETR.
        unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))

        self.unetr.load_state_dict(unetr_state)
        self.unetr.to(self.device)
        return save_dict

    def _instance_iteration(self, x, y, metric_for_val=False):
        """Perform the segmentation of distance maps and
        compute the loss (and metric) between the prediction and target.
        """
        outputs = self.unetr(x.to(self.device))
        loss = self.instance_loss(outputs, y.to(self.device))
        if metric_for_val:
            metric = self.instance_metric(outputs, y.to(self.device))
            return loss, metric
        else:
            return loss

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            # Channel 0 holds the instance labels, the remaining channels the decoder targets.
            labels_instances = y[:, 0, ...].unsqueeze(1)
            labels_for_unetr = y[:, 1:, ...]

            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                # 1. train for the interactive segmentation
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)

            backprop(loss)

            self.optimizer.zero_grad()

            with forward_context():
                # 2. train for the automatic instance segmentation
                unetr_loss = self._instance_iteration(x, labels_for_unetr)

            backprop(unetr_loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(
                    self._iteration, loss, lr, x, labels_instances, samples,
                    mask_loss, iou_regression_loss, model_iou, unetr_loss
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                labels_instances = y[:, 0, ...].unsqueeze(1)
                labels_for_unetr = y[:, 1:, ...]

                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    # 1. validate for the interactive segmentation
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)

                with forward_context():
                    # 2. validate for the automatic instance segmentation
                    unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metric_for_val=True)

                loss_val += loss.item()
                # Combine the interactive metric with the down-weighted instance metric.
                metric_val += metric.item() + (unetr_metric.item() / 3)
                model_iou_val += model_iou.item()
                val_iteration += 1

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        model_iou_val /= len(self.val_loader)

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
                mask_loss, iou_regression_loss, model_iou_val, unetr_loss
            )

        return metric_val


class JointSamLogger(TorchEmLogger):
    """@private"""
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else \
            os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        # Log the first sample of the batch; for 3d data use the central depth slice.
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())

        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(
        self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(
        self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
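Both the training and the validation loop split the label batch the same way: channel 0 carries the instance labels used to sample interactive prompts, and the remaining channels carry the targets for the convolutional decoder (the distance maps mentioned in `_instance_iteration`). A minimal sketch of that layout, with hypothetical shapes:

import torch

# Hypothetical batch: 2 images, 1 input channel, 512 x 512 pixels.
x = torch.randn(2, 1, 512, 512)
# Hypothetical labels with 4 channels: channel 0 holds the instance labels,
# channels 1-3 hold decoder targets (e.g. foreground and distance maps).
y = torch.randn(2, 4, 512, 512)

labels_instances = y[:, 0, ...].unsqueeze(1)  # shape: (2, 1, 512, 512)
labels_for_unetr = y[:, 1:, ...]              # shape: (2, 3, 512, 512)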
Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.

This class inherits from `SamTrainer`.
Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
for details on its implementation.

Arguments:
- unetr: The UNet-style model with vision transformer as the image encoder. Required to perform automatic instance segmentation.
- instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
- instance_metric: The metric to compare the predictions and the targets.
- kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
JointSamTrainer(unetr: torch.nn.Module, instance_loss: torch.nn.Module, instance_metric: torch.nn.Module, **kwargs)
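For orientation, here is a minimal construction sketch. It is a sketch under assumptions, not a complete recipe: the helper name and the loss/metric choices are placeholders, and depending on the version further `SamTrainer`/`DefaultTrainer` keywords may be required. The keyword names themselves (`name`, `model`, `loss`, `metric`, `optimizer`, `device`, `convert_inputs`, `n_objects_per_batch`, ...) are taken from the Inherited Members list below.

import torch
from micro_sam.training.joint_sam_trainer import JointSamTrainer

def make_joint_trainer(sam_model, unetr, train_loader, val_loader, convert_inputs):
    """Hypothetical helper that wires user-supplied components into a JointSamTrainer.

    `sam_model` is the trainable SAM wrapper, `unetr` shares its image encoder,
    and the loaders yield (image, label) batches with the channel layout
    sketched above (channel 0: instances, channels 1+: decoder targets).
    """
    placeholder_loss = torch.nn.MSELoss()  # stand-in for a proper segmentation loss
    return JointSamTrainer(
        name="joint-sam",                      # used for the checkpoint and log folders
        model=sam_model,
        unetr=unetr,
        instance_loss=placeholder_loss,
        instance_metric=placeholder_loss,
        loss=placeholder_loss,                 # DefaultTrainer kwarg
        metric=placeholder_loss,               # DefaultTrainer kwarg
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=torch.optim.AdamW(sam_model.parameters(), lr=1e-5),
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        convert_inputs=convert_inputs,         # SamTrainer kwarg: batch -> SAM inputs
        n_objects_per_batch=25,                # SamTrainer kwarg: prompts per batch
    )

# trainer = make_joint_trainer(...)
# trainer.fit(iterations=10_000)  # `fit` is inherited from DefaultTrainer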
def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
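`save_checkpoint` writes only the decoder half of the UNETR state (every key that does not start with "encoder"), since the encoder weights are already part of the saved SAM model state. A sketch for inspecting the result, assuming the `DefaultTrainer` convention of storing checkpoints under `./checkpoints/<name>/` (path and filename are assumptions):

import torch

# Assumed checkpoint location for a trainer named "joint-sam".
save_dict = torch.load("./checkpoints/joint-sam/best.pt", map_location="cpu")

print(sorted(save_dict.keys()))  # includes "model_state" and "decoder_state"
decoder_state = save_dict["decoder_state"]
# Only decoder parameters were stored; no key starts with "encoder".
assert not any(k.startswith("encoder") for k in decoder_state)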
def load_checkpoint(self, checkpoint="best"):
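`load_checkpoint` rebuilds the full UNETR state dict by stripping the `sam.image_` prefix from SAM's image-encoder keys and merging them with the stored decoder keys, so the shared encoder is restored from the SAM weights. A minimal usage sketch, continuing the hypothetical trainer from above:

# Restores the SAM model via the parent class and the UNETR from the merged
# encoder + decoder state; afterwards the decoder is ready for inference.
save_dict = trainer.load_checkpoint("best")
trainer.unetr.eval()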
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
  - convert_inputs
  - mse_loss
  - n_objects_per_batch
  - n_sub_iteration
  - prompt_generator
  - mask_prob
- torch_em.trainer.default_trainer.DefaultTrainer
  - name
  - id_
  - train_loader
  - val_loader
  - model
  - loss
  - optimizer
  - metric
  - device
  - lr_scheduler
  - log_image_interval
  - save_root
  - compile_model
  - rank
  - mixed_precision
  - early_stopping
  - train_time
  - logger_class
  - logger_kwargs
  - checkpoint_folder
  - iteration
  - epoch
  - Deserializer
  - from_checkpoint
  - Serializer
  - fit