micro_sam.training.joint_sam_trainer

import os
import time
import numpy as np
from collections import OrderedDict

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from .sam_trainer import SamTrainer

from torch_em.trainer.logger_base import TorchEmLogger
from torch_em.trainer.tensorboard_logger import normalize_im


class JointSamTrainer(SamTrainer):
    """Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.

    This class inherits from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Args:
        unetr: The UNet-style model with vision transformer as the image encoder.
            Required to perform automatic instance segmentation.
        instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
        instance_metric: The metric to compare the predictions and the targets.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
    """

    def __init__(
        self, unetr: torch.nn.Module, instance_loss: torch.nn.Module, instance_metric: torch.nn.Module, **kwargs
    ):
        super().__init__(**kwargs)
        self.unetr = unetr
        self.instance_loss = instance_loss
        self.instance_metric = instance_metric

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        # Store only the decoder parameters of the UNETR; the image encoder
        # is shared with SAM and is already saved as part of the model state.
        current_unetr_state = self.unetr.state_dict()
        decoder_state = []
        for k, v in current_unetr_state.items():
            if not k.startswith("encoder"):
                decoder_state.append((k, v))
        decoder_state = OrderedDict(decoder_state)

        super().save_checkpoint(
            name, current_metric=current_metric, best_metric=best_metric, decoder_state=decoder_state, **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        save_dict = super().load_checkpoint(checkpoint)

        # Get the image encoder parameters from the SAM state.
        sam_state = save_dict["model_state"]
        encoder_state = []
        prune_prefix = "sam.image_"
        for k, v in sam_state.items():
            if k.startswith(prune_prefix):
                encoder_state.append((k[len(prune_prefix):], v))
        encoder_state = OrderedDict(encoder_state)

        # Get the decoder parameters from the stored UNETR decoder state.
        decoder_state = save_dict["decoder_state"]

        # Merge the two to obtain the full parameter set for the UNETR.
        unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))

        self.unetr.load_state_dict(unetr_state)
        self.unetr.to(self.device)
        return save_dict

    def _instance_iteration(self, x, y, metric_for_val=False):
        """Predict the distance maps for instance segmentation and compute the loss
        (and optionally the metric) between the predictions and the targets.
        """
        outputs = self.unetr(x.to(self.device))
        loss = self.instance_loss(outputs, y.to(self.device))
        if metric_for_val:
            metric = self.instance_metric(outputs, y.to(self.device))
            return loss, metric
        else:
            return loss

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            # The first label channel holds the instance labels for the interactive task,
            # the remaining channels hold the targets for the automatic instance segmentation.
            labels_instances = y[:, 0, ...].unsqueeze(1)
            labels_for_unetr = y[:, 1:, ...]

            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                # 1. Train for the interactive segmentation.
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)

            backprop(loss)

            self.optimizer.zero_grad()

            with forward_context():
                # 2. Train for the automatic instance segmentation.
                unetr_loss = self._instance_iteration(x, labels_for_unetr)

            backprop(unetr_loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(
                    self._iteration, loss, lr, x, labels_instances, samples,
                    mask_loss, iou_regression_loss, model_iou, unetr_loss
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                labels_instances = y[:, 0, ...].unsqueeze(1)
                labels_for_unetr = y[:, 1:, ...]

                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    # 1. Validate the interactive segmentation.
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)

                with forward_context():
                    # 2. Validate the automatic instance segmentation.
                    unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metric_for_val=True)

                loss_val += loss.item()
                # Combine the interactive metric with the down-weighted instance segmentation metric.
                metric_val += metric.item() + (unetr_metric.item() / 3)
                model_iou_val += model_iou.item()
                val_iteration += 1

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        model_iou_val /= len(self.val_loader)

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
                mask_loss, iou_regression_loss, model_iou_val, unetr_loss
            )

        return metric_val


class JointSamLogger(TorchEmLogger):
    """@private"""
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        if save_root is None:
            self.log_dir = f"./logs/{trainer.name}"
        else:
            self.log_dir = os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        # Select the first sample of the batch; for 3d data, additionally take the central slice.
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())

        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(
        self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(
        self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
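
Each training iteration above performs two separate update steps on the same batch: first the interactive segmentation objective for SAM, then the automatic instance segmentation objective for the UNETR, each bracketed by its own `optimizer.zero_grad()`, with `backprop` handling the backward pass and optimizer step. The following is a minimal, self-contained sketch of this two-objective pattern; the toy encoder and heads are illustrative stand-ins, not part of micro_sam.

import torch
import torch.nn.functional as F

# Toy stand-ins: one shared encoder (SAM and the UNETR share the image encoder)
# plus a separate head per objective.
encoder = torch.nn.Linear(8, 8)
head_interactive = torch.nn.Linear(8, 1)  # stands in for SAM's mask / IoU prediction
head_instance = torch.nn.Linear(8, 3)     # stands in for the UNETR decoder

# A single optimizer covers the shared encoder and both heads.
optimizer = torch.optim.SGD(
    list(encoder.parameters()) + list(head_interactive.parameters()) + list(head_instance.parameters()),
    lr=0.01,
)

x = torch.randn(4, 8)
y_interactive = torch.randn(4, 1)
y_instance = torch.randn(4, 3)

# 1. Interactive objective: zero_grad / forward / backward / step.
optimizer.zero_grad()
loss = F.mse_loss(head_interactive(encoder(x)), y_interactive)
loss.backward()
optimizer.step()

# 2. Instance segmentation objective: a second, independent update on the same batch.
optimizer.zero_grad()
unetr_loss = F.mse_loss(head_instance(encoder(x)), y_instance)
unetr_loss.backward()
optimizer.step()

Because the shared encoder receives gradients from both passes, it is updated twice per batch, once under each objective.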
class JointSamTrainer(micro_sam.training.sam_trainer.SamTrainer):

Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.

This class inherits from SamTrainer. Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py for details on its implementation.

Arguments:
  • unetr: The UNet-style model with vision transformer as the image encoder. Required to perform automatic instance segmentation.
  • instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
  • instance_metric: The metric to compare the predictions and the targets.
  • kwargs: The keyword arguments of the SamTrainer (and DefaultTrainer) class.
JointSamTrainer(unetr: torch.nn.modules.module.Module, instance_loss: torch.nn.modules.module.Module, instance_metric: torch.nn.modules.module.Module, **kwargs)
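
For orientation, a hypothetical construction of the trainer could look as follows. `sam_model`, `unetr`, `train_loader` and `val_loader` are assumed to be built elsewhere; the keyword values are illustrative, not recommendations, and the `SamTrainer` keywords (including the assumed `ConvertToSamInputs` helper) are taken from the inherited members listed further down.

import torch
from torch_em.loss import DiceLoss

from micro_sam.training import ConvertToSamInputs
from micro_sam.training.joint_sam_trainer import JointSamTrainer

# The optimizer must cover the SAM parameters and the UNETR decoder; the
# UNETR encoder is shared with SAM, so its parameters are excluded here.
joint_params = list(sam_model.parameters()) + [
    param for name, param in unetr.named_parameters() if not name.startswith("encoder")
]
optimizer = torch.optim.AdamW(joint_params, lr=1e-5)

trainer = JointSamTrainer(
    name="joint_sam",
    model=sam_model,              # trainable SAM wrapper for the interactive task
    unetr=unetr,                  # UNETR with the shared image encoder
    instance_loss=DiceLoss(),     # loss for the automatic instance segmentation
    instance_metric=DiceLoss(),   # metric used during validation
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=torch.device("cuda"),
    # SamTrainer keywords; assumes the trainable SAM exposes its resize transform.
    convert_inputs=ConvertToSamInputs(transform=sam_model.transform),
    n_objects_per_batch=25,
    n_sub_iteration=8,
)
trainer.fit(iterations=10000)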
unetr
instance_loss
instance_metric
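
save_checkpoint and load_checkpoint split the UNETR state dict by key prefix: only the decoder weights (keys not starting with "encoder") go into the checkpoint, and on loading the encoder weights are recovered from the SAM state (keys starting with "sam.image_"). Below is a minimal sketch of this bookkeeping with tiny stand-in modules; the classes are hypothetical, only the key prefixes match the real models.

from collections import OrderedDict

import torch


class TinySam(torch.nn.Module):
    # Stand-in for the trainable SAM wrapper: encoder weights live under "sam.image_encoder.*".
    def __init__(self):
        super().__init__()
        self.sam = torch.nn.Module()
        self.sam.image_encoder = torch.nn.Linear(4, 4)


class TinyUnetr(torch.nn.Module):
    # Stand-in for the UNETR: the same encoder plus a decoder.
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(4, 4)
        self.decoder = torch.nn.Linear(4, 3)


sam, unetr = TinySam(), TinyUnetr()

# save_checkpoint: keep only the decoder part of the UNETR state.
decoder_state = OrderedDict(
    (k, v) for k, v in unetr.state_dict().items() if not k.startswith("encoder")
)

# load_checkpoint: take the encoder weights from the SAM state, dropping the
# "sam.image_" prefix ("sam.image_encoder.*" becomes "encoder.*"), then merge
# them with the stored decoder weights to rebuild the full UNETR state.
prune_prefix = "sam.image_"
encoder_state = OrderedDict(
    (k[len(prune_prefix):], v) for k, v in sam.state_dict().items() if k.startswith(prune_prefix)
)
unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))
unetr.load_state_dict(unetr_state)

print(sorted(unetr_state))  # ['decoder.bias', 'decoder.weight', 'encoder.bias', 'encoder.weight']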
Inherited Members
micro_sam.training.sam_trainer.SamTrainer
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit
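
After training, a single load_checkpoint call restores both halves: the parent class reloads the SAM model, and the UNETR is reassembled from the encoder weights in "model_state" plus the stored "decoder_state". Continuing the hypothetical setup above:

save_dict = trainer.load_checkpoint("best")
# Both pieces that load_checkpoint consumes are present in the save dict.
assert "model_state" in save_dict and "decoder_state" in save_dict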