micro_sam.training.joint_sam_trainer

import os
import time
import numpy as np
from collections import OrderedDict

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from .sam_trainer import SamTrainer

from torch_em.trainer.logger_base import TorchEmLogger
from torch_em.trainer.tensorboard_logger import normalize_im


class JointSamTrainer(SamTrainer):
    """Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.

    This class is inherited from `SamTrainer`.
    Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py
    for details on its implementation.

    Args:
        unetr: The UNet-style model with vision transformer as the image encoder.
            Required to perform automatic instance segmentation.
        instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
        instance_metric: The metric to compare the predictions and the targets.
        kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class.
    """
    def __init__(
        self,
        unetr: torch.nn.Module,
        instance_loss: torch.nn.Module,
        instance_metric: torch.nn.Module,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.unetr = unetr
        self.instance_loss = instance_loss
        self.instance_metric = instance_metric

    def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
        current_unetr_state = self.unetr.state_dict()
        decoder_state = []
        for k, v in current_unetr_state.items():
            if not k.startswith("encoder"):
                decoder_state.append((k, v))
        decoder_state = OrderedDict(decoder_state)

        super().save_checkpoint(
            name, current_metric=current_metric, best_metric=best_metric, decoder_state=decoder_state, **extra_save_dict
        )

    def load_checkpoint(self, checkpoint="best"):
        save_dict = super().load_checkpoint(checkpoint)

        # let's get the image encoder params from sam
        sam_state = save_dict["model_state"]
        encoder_state = []
        prune_prefix = "sam.image_"
        for k, v in sam_state.items():
            if k.startswith(prune_prefix):
                encoder_state.append((k[len(prune_prefix):], v))
        encoder_state = OrderedDict(encoder_state)

        # let's get the decoder params from unetr
        decoder_state = save_dict["decoder_state"]

        # now let's merge the two to get the params for the unetr
        unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))

        self.unetr.load_state_dict(unetr_state)
        self.unetr.to(self.device)
        return save_dict

    def _instance_iteration(self, x, y, metric_for_val=False):
        """Perform the segmentation of distance maps and
        compute the loss (and metric) between the prediction and target.
        """
        outputs = self.unetr(x.to(self.device))
        loss = self.instance_loss(outputs, y.to(self.device))
        if metric_for_val:
            metric = self.instance_metric(outputs, y.to(self.device))
            return loss, metric
        else:
            return loss

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            labels_instances = y[:, 0, ...].unsqueeze(1)
            labels_for_unetr = y[:, 1:, ...]

            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                # 1. train for the interactive segmentation
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)

            backprop(loss)

            self.optimizer.zero_grad()

            with forward_context():
                # 2. train for the automatic instance segmentation
                unetr_loss = self._instance_iteration(x, labels_for_unetr)

            backprop(unetr_loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(
                    self._iteration, loss, lr, x, labels_instances, samples,
                    mask_loss, iou_regression_loss, model_iou, unetr_loss
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                labels_instances = y[:, 0, ...].unsqueeze(1)
                labels_for_unetr = y[:, 1:, ...]

                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    # 1. validate for the interactive segmentation
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, labels_instances, val_iteration)

                with forward_context():
                    # 2. validate for the automatic instance segmentation
                    unetr_loss, unetr_metric = self._instance_iteration(x, labels_for_unetr, metric_for_val=True)

                loss_val += loss.item()
                metric_val += metric.item() + (unetr_metric.item() / 3)
                model_iou_val += model_iou.item()
                val_iteration += 1

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        model_iou_val /= len(self.val_loader)

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, labels_instances, sampled_binary_y,
                mask_loss, iou_regression_loss, model_iou_val, unetr_loss
            )

        return metric_val


class JointSamLogger(TorchEmLogger):
    """@private"""
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]

        image = normalize_im(x[selection].cpu())

        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[selection], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(
            self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(
            self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou, instance_loss
    ):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/instance_loss", scalar_value=instance_loss, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
class JointSamTrainer(micro_sam.training.sam_trainer.SamTrainer):

Trainer class for jointly training the Segment Anything model with an additional convolutional decoder.

This class inherits from SamTrainer. Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py for details on its implementation. A minimal usage sketch is given after the argument list below.

Arguments:
  • unetr: The UNet-style model with vision transformer as the image encoder. Required to perform automatic instance segmentation.
  • instance_loss: The loss to compare the predictions (for instance segmentation) and the targets.
  • instance_metric: The metric to compare the predictions and the targets.
  • kwargs: The keyword arguments of the SamTrainer (and DefaultTrainer) class.
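A minimal usage sketch, under stated assumptions: the trainable SAM model (sam_model), the UNETR sharing its image encoder with SAM (unetr), the data loaders, and the distance-based instance loss/metric are assumed to have been built beforehand (for example with micro_sam.training and torch_em utilities). The exact set of required SamTrainer / DefaultTrainer keyword arguments is defined in those classes, not in this module, so the argument names shown here are assumptions rather than the definitive call.

import torch
from micro_sam.training.joint_sam_trainer import JointSamTrainer, JointSamLogger

# The UNETR shares its image encoder with SAM, so only its non-encoder
# (decoder) parameters are added to the optimizer, mirroring the encoder/decoder
# split used by `save_checkpoint` above.
joint_params = list(sam_model.parameters()) + [
    param for name, param in unetr.named_parameters() if not name.startswith("encoder")
]
optimizer = torch.optim.AdamW(joint_params, lr=1e-5)

trainer = JointSamTrainer(
    name="joint_sam",                 # determines the checkpoint and log folder names
    model=sam_model,
    unetr=unetr,
    instance_loss=instance_loss,      # e.g. a dice-based loss on the distance-map channels
    instance_metric=instance_metric,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device="cuda" if torch.cuda.is_available() else "cpu",
    logger=JointSamLogger,
    # further SamTrainer / DefaultTrainer arguments (loss, metric, convert_inputs,
    # n_objects_per_batch, n_sub_iteration, ...) are typically required as well
    # and are omitted here for brevity.
)
trainer.fit(iterations=10_000)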
JointSamTrainer(unetr: torch.nn.modules.module.Module, instance_loss: torch.nn.modules.module.Module, instance_metric: torch.nn.modules.module.Module, **kwargs)
unetr
instance_loss
instance_metric
def save_checkpoint(self, name, current_metric, best_metric, **extra_save_dict):
def load_checkpoint(self, checkpoint='best'):
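The same encoder/decoder merge that load_checkpoint performs can also be done outside of training, e.g. to rebuild the UNETR from a finished run. In the sketch below the checkpoint path follows torch_em's DefaultTrainer convention (./checkpoints/<name>/best.pt) and is an assumption, and `unetr` is a hypothetical model instance with the same architecture as during training; the keys "model_state" and "decoder_state" are the ones written by save_checkpoint above.

from collections import OrderedDict
import torch

# Load the checkpoint written during joint training (path is an assumption,
# following the default torch_em checkpoint layout).
save_dict = torch.load("./checkpoints/joint_sam/best.pt", map_location="cpu")

# Image encoder parameters live in the SAM state under the "sam.image_" prefix;
# stripping it yields keys starting with "encoder.", matching the UNETR.
prune_prefix = "sam.image_"
encoder_state = OrderedDict(
    (k[len(prune_prefix):], v)
    for k, v in save_dict["model_state"].items() if k.startswith(prune_prefix)
)

# Decoder parameters were stored separately by JointSamTrainer.save_checkpoint.
decoder_state = save_dict["decoder_state"]

# Merge both to obtain the full UNETR state dict.
unetr_state = OrderedDict(list(encoder_state.items()) + list(decoder_state.items()))
# unetr.load_state_dict(unetr_state)  # `unetr` is a hypothetical, matching model instance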
Inherited Members
  micro_sam.training.sam_trainer.SamTrainer
    • convert_inputs
    • mse_loss
    • n_objects_per_batch
    • n_sub_iteration
    • prompt_generator
    • mask_prob
  torch_em.trainer.default_trainer.DefaultTrainer
    • name
    • id_
    • train_loader
    • val_loader
    • model
    • loss
    • optimizer
    • metric
    • device
    • lr_scheduler
    • log_image_interval
    • save_root
    • compile_model
    • rank
    • mixed_precision
    • early_stopping
    • train_time
    • logger_class
    • logger_kwargs
    • checkpoint_folder
    • iteration
    • epoch
    • Deserializer
    • from_checkpoint
    • Serializer
    • fit