micro_sam.training.joint_sam_trainer
import os
import time
import numpy as np
from collections import OrderedDict

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from .sam_trainer import SamTrainer

from torch_em.trainer.logger_base import TorchEmLogger
from torch_em.trainer.tensorboard_logger import normalize_im


class JointSamTrainer(SamTrainer):
    """Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Args:
        unetr: The UNet-style model with vision transformer as the image encoder.
            Required to perform automatic instance segmentation.
        instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
        instance_metric: The metric to compare the predictions and the targets.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
    """

    def __init__(
        self, unetr: torch.nn.Module, instance_loss: torch.nn.Module, instance_metric: torch.nn.Module, **kwargs
    ):
        super().__init__(**kwargs)
        self.unetr = unetr
        self.instance_loss = instance_loss
        self.instance_metric = instance_metric

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        # Store only the decoder part of the UNETR; the encoder is shared with SAM
        # and already saved as part of the SAM model state.
        current_unetr_state = self.unetr.state_dict()
        decoder_state = []
        for k, v in current_unetr_state.items():
            if not k.startswith("encoder"):
                decoder_state.append((k, v))
        decoder_state = OrderedDict(decoder_state)

        super().save_checkpoint(
            name, current_metric=current_metric, best_metric=best_metric, decoder_state=decoder_state, **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        save_dict = super().load_checkpoint(checkpoint)

        # Get the image encoder parameters from SAM.
        sam_state = save_dict["model_state"]
        encoder_state = []
        prune_prefix = "sam.image_"
        for k, v in sam_state.items():
            if k.startswith(prune_prefix):
                encoder_state.append((k[len(prune_prefix):], v))
        encoder_state = OrderedDict(encoder_state)

        # Get the decoder parameters stored for the UNETR.
        decoder_state = save_dict["decoder_state"]

        # Merge the two to obtain the full UNETR state.
        unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))

        self.unetr.load_state_dict(unetr_state)
        self.unetr.to(self.device)
        return save_dict

    def _instance_iteration(self, x, y, metric_for_val=False):
        """Perform the segmentation of distance maps and
        compute the loss (and metric) between the prediction and target.
        """
        outputs = self.unetr(x.to(self.device))
        loss = self.instance_loss(outputs, y.to(self.device))
        if metric_for_val:
            metric = self.instance_metric(outputs, y.to(self.device))
            return loss, metric
        else:
            return loss

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            labels_instances = y[:, 0, ...].unsqueeze(1)
            labels_for_unetr = y[:, 1:, ...]

            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                # 1. train for the interactive segmentation
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)

            backprop(loss)

            self.optimizer.zero_grad()

            with forward_context():
                # 2. train for the automatic instance segmentation
                unetr_loss = self._instance_iteration(x, labels_for_unetr)

            backprop(unetr_loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(
                    self._iteration, loss, lr, x, labels_instances, samples,
                    mask_loss, iou_regression_loss, model_iou, unetr_loss
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
        mask_loss_val, iou_loss_val, unetr_loss_val = 0.0, 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                labels_instances = y[:, 0, ...].unsqueeze(1)
                labels_for_unetr = y[:, 1:, ...]

                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    # 1. validate for the interactive segmentation
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)

                with forward_context():
                    # 2. validate for the automatic instance segmentation
                    unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metric_for_val=True)

                loss_val += loss.item()
                metric_val += metric.item() + (unetr_metric.item() / 3)
                mask_loss_val += mask_loss.item()
                iou_loss_val += iou_regression_loss.item()
                model_iou_val += model_iou.item()
                unetr_loss_val += unetr_loss.item()
                val_iteration += 1

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        mask_loss_val /= len(self.val_loader)
        iou_loss_val /= len(self.val_loader)
        model_iou_val /= len(self.val_loader)
        unetr_loss_val /= len(self.val_loader)

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
                mask_loss_val, iou_loss_val, model_iou_val, unetr_loss_val
            )

        return metric_val


class JointSamLogger(TorchEmLogger):
    """@private"""
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        # For 3d inputs, log the central slice of the first sample in the batch.
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())

        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(
        self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(
        self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
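The train and validation loops above split each label batch along the channel axis: channel 0 carries the instance labels for the interactive (prompt-based) objective, while the remaining channels carry the targets for the UNETR decoder. A minimal sketch of that layout, with the channel count and spatial size chosen purely for illustration:

import torch

# Illustrative label batch of shape (batch, channels, height, width);
# these dimensions are assumptions, not values prescribed by the trainer.
y = torch.zeros(2, 4, 512, 512)

labels_instances = y[:, 0, ...].unsqueeze(1)  # (2, 1, 512, 512): instance labels for SAM
labels_for_unetr = y[:, 1:, ...]              # (2, 3, 512, 512): decoder targets, e.g. distance maps

print(labels_instances.shape, labels_for_unetr.shape)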
Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.
This class inherits from SamTrainer.
Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
for details on its implementation.
Arguments:
- unetr: The UNet-style model with vision transformer as the image encoder. Required to perform automatic instance segmentation.
- instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
- instance_metric: The metric to compare the predictions and the targets.
- kwargs: The keyword arguments of the SamTrainer (and DefaultTrainer) class.
JointSamTrainer(unetr: torch.nn.Module, instance_loss: torch.nn.Module, instance_metric: torch.nn.Module, **kwargs)
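As save_checkpoint and load_checkpoint in the source above show, a joint checkpoint keeps the SAM weights under "model_state" and the additional decoder weights under "decoder_state"; the UNETR state is rebuilt by pruning the "sam.image_" prefix from the encoder keys and appending the decoder keys. A minimal sketch of the same reassembly outside the trainer, where the checkpoint path and the unetr instance are placeholders:

from collections import OrderedDict

import torch

# Placeholder path; the actual location depends on save_root and the trainer name.
save_dict = torch.load("checkpoints/joint_sam/best.pt", map_location="cpu")

# Encoder weights come from the SAM state, with the "sam.image_" prefix pruned
# (e.g. "sam.image_encoder.blocks.0..." becomes "encoder.blocks.0...").
prune_prefix = "sam.image_"
encoder_state = OrderedDict(
    (k[len(prune_prefix):], v)
    for k, v in save_dict["model_state"].items()
    if k.startswith(prune_prefix)
)

# Decoder weights were stored separately by save_checkpoint.
decoder_state = save_dict["decoder_state"]

# Merge both into a state dict for a UNETR with a matching architecture.
unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))
# unetr.load_state_dict(unetr_state)  # `unetr` is a placeholder instance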
Inherited Members
- micro_sam.training.sam_trainer.SamTrainer
- convert_inputs
- mse_loss
- n_objects_per_batch
- n_sub_iteration
- prompt_generator
- mask_prob
- is_data_parallel
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit
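Training is driven by the fit method inherited from DefaultTrainer. A minimal construction sketch; the SAM model, UNETR, and data loaders are placeholders the caller must provide, and the loss/metric choices below are assumptions rather than the library's prescribed recipe:

import torch

from micro_sam.training.joint_sam_trainer import JointSamTrainer


def build_joint_trainer(sam_model, unetr, train_loader, val_loader):
    # Placeholder factory: `sam_model` is a trainable SAM wrapper, `unetr` a
    # UNETR sharing its image encoder, and the loaders are torch_em-style.
    return JointSamTrainer(
        name="joint_sam",  # folder name for checkpoints and logs
        model=sam_model,
        unetr=unetr,
        instance_loss=torch.nn.MSELoss(),    # assumed decoder loss
        instance_metric=torch.nn.MSELoss(),  # assumed decoder metric
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=torch.optim.AdamW(sam_model.parameters(), lr=1e-5),
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        # Further required kwargs from SamTrainer/DefaultTrainer (see the
        # inherited members above), e.g. loss, metric, convert_inputs,
        # n_objects_per_batch, n_sub_iteration, prompt_generator, are omitted.
    )


# trainer = build_joint_trainer(sam_model, unetr, train_loader, val_loader)
# trainer.fit(iterations=25000)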