micro_sam.training.semantic_sam_trainer

import time
from typing import Optional

import torch
import torch.nn as nn

from torch_em.loss import DiceLoss
from torch_em.trainer import DefaultTrainer


class CustomDiceLoss(nn.Module):
    """Loss for computing dice over one-hot labels.

    Expects prediction and target with `num_classes` channels: the number of classes for semantic segmentation.

    Args:
        num_classes: The number of classes for semantic segmentation (including background class).
        softmax: Whether to use softmax over the predictions.
    """
    def __init__(self, num_classes: int, softmax: bool = True) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.dice_loss = DiceLoss()
        self.softmax = softmax

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.num_classes):
            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def __call__(self, pred, target):
        if self.softmax:
            pred = torch.softmax(pred, dim=1)
        target = self._one_hot_encoder(target)
        loss = self.dice_loss(pred, target)
        return loss


class SemanticSamTrainer(DefaultTrainer):
    """Trainer class for training the Segment Anything model for semantic segmentation.

    This class is derived from `torch_em.trainer.DefaultTrainer`.
    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
    for details on its usage and implementation.

    Args:
        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
            The class `micro_sam.training.util.ConvertToSemanticSamInputs` can be used here.
        num_classes: The number of classes for semantic segmentation (including the background class).
        dice_weight: The weighing for the dice loss in the combined dice-cross entropy loss function.
        kwargs: The keyword arguments of the DefaultTrainer super class.
    """
    def __init__(
        self,
        convert_inputs,
        num_classes: int,
        dice_weight: Optional[float] = None,
        **kwargs
    ):
        assert num_classes > 1

        if "loss" not in kwargs:
            kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)

        if "metric" not in kwargs:
            kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)

        super().__init__(**kwargs)

        self.convert_inputs = convert_inputs
        self.num_classes = num_classes
        self.compute_ce_loss = nn.CrossEntropyLoss()
        self.dice_weight = dice_weight

        if self.dice_weight is not None:
            assert self.dice_weight > 0 and self.dice_weight < 1, "The weight factor should lie between 0 and 1."

        self._kwargs = kwargs

    def _compute_loss(self, y, masks):
        """Compute the combined (weighted) dice loss and cross-entropy loss between the prediction and target.
        """
        target = y.to(self.device, non_blocking=True)
        # Compute dice loss for the predictions
        dice_loss = self.loss(masks, target)

        # Compute cross entropy loss for the predictions
        ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long())

        if self.dice_weight is None:
            net_loss = dice_loss + ce_loss
        else:
            net_loss = self.dice_weight * dice_loss + (1 - self.dice_weight) * ce_loss

        return net_loss

    def _get_model_outputs(self, batched_inputs):
        """Get the predictions from the model.
        """
        # Precompute the image embeddings if the model exposes it as functionality.
        if hasattr(self.model, "image_embeddings_oft"):
            image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
            batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True)
        else:  # Otherwise we assume that the embeddings are computed internally as part of the forward pass.
            # We need to take care of sending things to the device here.
            batched_inputs = [
                {"image": inp["image"].to(self.device, non_blocking=True), "original_size": inp["original_size"]}
                for inp in batched_inputs
            ]
            batched_outputs = self.model(batched_inputs, multimask_output=True)

        masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
        return masks

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        t_per_iter = time.time()
        for x, y in self.train_loader:
            self.optimizer.zero_grad()

            batched_inputs = self.convert_inputs(x, y)

            with forward_context():
                masks = self._get_model_outputs(batched_inputs)
                net_loss = self._compute_loss(y, masks)

            backprop(net_loss)

            self._iteration += 1

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_train(
                    self._iteration, net_loss, lr, x, y, torch.softmax(masks, dim=1), log_gradients=False
                )

            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter)
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        metric_val, loss_val = 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                batched_inputs = self.convert_inputs(x, y)

                with forward_context():
                    masks = self._get_model_outputs(batched_inputs)
                    net_loss = self._compute_loss(y, masks)

                loss_val += net_loss.item()
                metric_val += net_loss.item()

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        dice_metric = 1 - (metric_val / self.num_classes)
        print()
        print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)
            )

        return metric_val


class SemanticMapsSamTrainer(SemanticSamTrainer):
    def _compute_loss(self, y, masks):
        target = y.to(self.device, non_blocking=True)

        # Compute loss for the predictions
        net_loss = self.loss(target, masks)

        return net_loss
class CustomDiceLoss(torch.nn.modules.module.Module):

Loss for computing dice over one-hot labels.

Expects predictions with num_classes channels (the number of classes for semantic segmentation); the target is given as integer class labels and is one-hot encoded to num_classes channels internally.

Arguments:
  • num_classes: The number of classes for semantic segmentation (including background class).
  • softmax: Whether to use softmax over the predictions.
CustomDiceLoss(num_classes: int, softmax: bool = True)

Attributes: num_classes, dice_loss, softmax
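
A minimal usage sketch with dummy tensors (the shapes are illustrative, not prescribed by the class): the predictions are raw logits with num_classes channels, the target is an integer label image with a singleton channel dimension.

import torch
from micro_sam.training.semantic_sam_trainer import CustomDiceLoss

loss_fn = CustomDiceLoss(num_classes=3)

pred = torch.randn(2, 3, 64, 64)                      # logits: (batch, num_classes, H, W)
target = torch.randint(0, 3, (2, 1, 64, 64)).float()  # class labels: (batch, 1, H, W)

# Softmax is applied to the predictions and the target is one-hot encoded internally.
loss = loss_fn(pred, target)
print(loss.item())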
class SemanticSamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):

Trainer class for training the Segment Anything model for semantic segmentation.

This class is derived from torch_em.trainer.DefaultTrainer. Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py for details on its usage and implementation.

Arguments:
  • convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSemanticSamInputs can be used here.
  • num_classes: The number of classes for semantic segmentation (including the background class).
  • dice_weight: The weighting of the dice loss in the combined dice-cross entropy loss. It must lie strictly between 0 and 1; if not given, the dice and cross-entropy losses are simply summed.
  • kwargs: The keyword arguments of the DefaultTrainer superclass.
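
For illustration, a standalone sketch of the combined loss computed in _compute_loss, on dummy tensors; the shapes and the dice_weight value are assumptions made for the example, not requirements of the trainer.

import torch
import torch.nn as nn
from micro_sam.training.semantic_sam_trainer import CustomDiceLoss

num_classes, dice_weight = 3, 0.5
dice = CustomDiceLoss(num_classes=num_classes)  # softmax + one-hot dice, the trainer's default loss
ce = nn.CrossEntropyLoss()

masks = torch.randn(2, num_classes, 64, 64)                     # model logits
target = torch.randint(0, num_classes, (2, 1, 64, 64)).float()  # integer labels with a channel dimension

dice_loss = dice(masks, target)
ce_loss = ce(masks, target.squeeze(1).long())
net_loss = dice_weight * dice_loss + (1 - dice_weight) * ce_loss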
SemanticSamTrainer(convert_inputs, num_classes: int, dice_weight: Optional[float] = None, **kwargs)
Attributes: convert_inputs, num_classes, compute_ce_loss, dice_weight
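
The trainer leaves the concrete model and input conversion open. The following sketch uses made-up stand-ins to illustrate the interface that _get_model_outputs and _train_epoch_impl rely on: convert_inputs(x, y) returns a list of per-image dicts with "image" and "original_size" entries, and the model returns one dict with a "masks" entry per image, which the trainer stacks into a (batch, num_classes, H, W) tensor.

import torch
import torch.nn as nn

def dummy_convert_inputs(x, y):
    # Stand-in for e.g. micro_sam.training.util.ConvertToSemanticSamInputs; the labels are not needed here.
    return [{"image": image, "original_size": image.shape[-2:]} for image in x]

class DummySemanticSam(nn.Module):
    # Stand-in for a SAM-style semantic segmentation model that predicts per-class logits.
    def __init__(self, num_classes):
        super().__init__()
        self.head = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, batched_inputs, multimask_output=True):
        return [{"masks": self.head(inp["image"].unsqueeze(0))} for inp in batched_inputs]

x = torch.randn(2, 3, 64, 64)                    # images from the data loader
y = torch.randint(0, 3, (2, 1, 64, 64)).float()  # labels from the data loader

model = DummySemanticSam(num_classes=3)
batched_inputs = dummy_convert_inputs(x, y)
batched_outputs = model(batched_inputs, multimask_output=True)
masks = torch.stack([out["masks"].squeeze(0) for out in batched_outputs])
print(masks.shape)  # torch.Size([2, 3, 64, 64])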
class SemanticMapsSamTrainer(SemanticSamTrainer):

Trainer class for training the Segment Anything model for semantic segmentation with a single, directly applied loss. It is derived from SemanticSamTrainer and only overrides _compute_loss: the configured loss (by default CustomDiceLoss, otherwise whatever is passed via the loss keyword argument) is applied directly to the target and the predicted masks, without the additional cross-entropy term, so dice_weight has no effect here. See SemanticSamTrainer above for the constructor arguments.
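
A minimal sketch of that loss computation on dummy tensors, outside of the trainer; nn.MSELoss and dense (map-style) targets are illustrative assumptions here, the actual loss is whatever is passed to the trainer via the loss keyword argument.

import torch
import torch.nn as nn

loss_fn = nn.MSELoss()              # hypothetical stand-in for the configured loss

masks = torch.randn(2, 3, 64, 64)   # model predictions
target = torch.randn(2, 3, 64, 64)  # dense target maps with matching shape

# Mirrors SemanticMapsSamTrainer._compute_loss: only the configured loss is applied,
# with the target passed as the first argument.
net_loss = loss_fn(target, masks)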