micro_sam.training.semantic_sam_trainer

  1import time
  2from typing import Optional
  3
  4import torch
  5import torch.nn as nn
  6
  7from torch_em.loss import DiceLoss
  8from torch_em.trainer import DefaultTrainer
  9
 10
 11class CustomDiceLoss(nn.Module):
 12    """Loss for computing dice over one-hot labels.
 13
 14    Expects prediction and target with `num_classes` channels: the number of classes for semantic segmentation.
 15
 16    Args:
 17        num_classes: The number of classes for semantic segmentation (including background class).
 18        softmax: Whether to use softmax over the predictions.
 19    """
 20    def __init__(self, num_classes: int, softmax: bool = True) -> None:
 21        super().__init__()
 22        self.num_classes = num_classes
 23        self.dice_loss = DiceLoss()
 24        self.softmax = softmax
 25
 26    def _one_hot_encoder(self, input_tensor):
 27        tensor_list = []
 28        for i in range(self.num_classes):
 29            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
 30            tensor_list.append(temp_prob)
 31        output_tensor = torch.cat(tensor_list, dim=1)
 32        return output_tensor.float()
 33
 34    def __call__(self, pred, target):
 35        if self.softmax:
 36            pred = torch.softmax(pred, dim=1)
 37        target = self._one_hot_encoder(target)
 38        loss = self.dice_loss(pred, target)
 39        return loss
 40
 41
 42class SemanticSamTrainer(DefaultTrainer):
 43    """Trainer class for training the Segment Anything model for semantic segmentation.
 44
 45    This class is derived from `torch_em.trainer.DefaultTrainer`.
 46    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 47    for details on its usage and implementation.
 48
 49    Args:
 50        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 51            The class `micro_sam.training.util.ConvertToSemanticSamInputs` can be used here.
 52        num_classes: The number of classes for semantic segmentation (including the background class).
 53        dice_weight: The weighing for the dice loss in the combined dice-cross entropy loss function.
 54        kwargs: The keyword arguments of the DefaultTrainer super class.
 55    """
 56
 57    def __init__(self, convert_inputs, num_classes: int, dice_weight: Optional[float] = None, **kwargs):
 58        assert num_classes > 1
 59
 60        if "loss" not in kwargs:
 61            kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)
 62
 63        if "metric" not in kwargs:
 64            kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)
 65
 66        super().__init__(**kwargs)
 67
 68        self.convert_inputs = convert_inputs
 69        self.num_classes = num_classes
 70        self.compute_ce_loss = nn.CrossEntropyLoss()
 71        self.dice_weight = dice_weight
 72
 73        if self.dice_weight is not None and (self.dice_weight < 0 or self.dice_weight > 1):
 74            raise ValueError("The weight factor should lie between 0 and 1.")
 75
 76        self._kwargs = kwargs
 77
 78    def _compute_loss(self, y, masks):
 79        """Compute the combined (weighted) dice loss and cross-entropy loss between the prediction and target.
 80        """
 81        target = y.to(self.device, non_blocking=True)
 82        # Compute dice loss for the predictions
 83        dice_loss = self.loss(masks, target)
 84
 85        # Compute cross entropy loss for the predictions
 86        ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long())
 87
 88        if self.dice_weight is None:
 89            net_loss = dice_loss + ce_loss
 90        else:
 91            net_loss = self.dice_weight * dice_loss + (1 - self.dice_weight) * ce_loss
 92
 93        return net_loss
 94
 95    def _get_model_outputs(self, batched_inputs):
 96        """Get the predictions from the model.
 97        """
 98        # Precompute the image embeddings if the model exposes it as functionality.
 99        if hasattr(self.model, "image_embeddings_oft"):
100            image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
101            batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True)
102        else:  # Otherwise we assume that the embeddings are computed internally as part of the forward pass.
103            # We need to take care of sending things to the device here.
104            batched_inputs = [
105                {"image": inp["image"].to(self.device, non_blocking=True), "original_size": inp["original_size"]}
106                for inp in batched_inputs
107            ]
108            batched_outputs = self.model(batched_inputs, multimask_output=True)
109
110        masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
111        return masks
112
113    def _train_epoch_impl(self, progress, forward_context, backprop):
114        self.model.train()
115
116        t_per_iter = time.time()
117        for x, y in self.train_loader:
118            self.optimizer.zero_grad()
119
120            batched_inputs = self.convert_inputs(x, y)
121
122            with forward_context():
123                masks = self._get_model_outputs(batched_inputs)
124                net_loss = self._compute_loss(y, masks)
125
126            backprop(net_loss)
127
128            self._iteration += 1
129
130            if self.logger is not None:
131                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
132                self.logger.log_train(
133                    self._iteration, net_loss, lr, x, y, torch.softmax(masks, dim=1), log_gradients=False
134                )
135
136            if self._iteration >= self.max_iteration:
137                break
138            progress.update(1)
139
140        t_per_iter = (time.time() - t_per_iter)
141        return t_per_iter
142
143    def _validate_impl(self, forward_context):
144        self.model.eval()
145
146        metric_val, loss_val = 0.0, 0.0
147
148        with torch.no_grad():
149            for x, y in self.val_loader:
150                batched_inputs = self.convert_inputs(x, y)
151
152                with forward_context():
153                    masks = self._get_model_outputs(batched_inputs)
154                    net_loss = self._compute_loss(y, masks)
155
156                loss_val += net_loss.item()
157                metric_val += net_loss.item()
158
159        loss_val /= len(self.val_loader)
160        metric_val /= len(self.val_loader)
161        dice_metric = 1 - (metric_val / self.num_classes)
162        print()
163        print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")
164
165        if self.logger is not None:
166            self.logger.log_validation(
167                self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)
168            )
169
170        return metric_val
171
172
class SemanticMapsSamTrainer(SemanticSamTrainer):
    """Variant of `SemanticSamTrainer` that trains with `self.loss` alone.

    Unlike the parent class, no cross-entropy term is added and `dice_weight` is not used
    when computing the loss.
    """

    def _compute_loss(self, y, masks):
        """Compute `self.loss` between the predictions `masks` and the target `y` (moved to the trainer device)."""
        target = y.to(self.device, non_blocking=True)

        # Use the (prediction, target) argument order of the parent class and of torch_em losses.
        # This matters for asymmetric losses such as the default CustomDiceLoss, which applies
        # softmax to its first argument and one-hot encodes its second.
        net_loss = self.loss(masks, target)

        return net_loss
class CustomDiceLoss(torch.nn.modules.module.Module):
class CustomDiceLoss(nn.Module):
    """Loss for computing dice over one-hot labels.

    Expects predictions with `num_classes` channels (the number of classes for semantic segmentation)
    and a single-channel target label map holding the class ids, which is one-hot encoded internally.

    Args:
        num_classes: The number of classes for semantic segmentation (including background class).
        softmax: Whether to use softmax over the predictions.
    """
    def __init__(self, num_classes: int, softmax: bool = True) -> None:
        """Initialize the loss.

        Args:
            num_classes: The number of classes for semantic segmentation (including background class).
            softmax: Whether to apply softmax over the channel axis of the predictions.
        """
        super().__init__()
        self.num_classes = num_classes
        self.dice_loss = DiceLoss()
        self.softmax = softmax

    def _one_hot_encoder(self, input_tensor):
        """One-hot encode a label tensor along the channel axis (dim 1).

        Each class id `i` in `range(num_classes)` becomes one float channel that is 1 where
        `input_tensor == i` and 0 elsewhere; label values outside that range yield all-zero channels.
        """
        # Assumes input_tensor has a singleton channel axis at dim 1 (e.g. (B, 1, H, W)) - TODO confirm;
        # with C input channels the concatenation would produce num_classes * C channels.
        return torch.cat([input_tensor == class_id for class_id in range(self.num_classes)], dim=1).float()

    def forward(self, pred, target):
        """Compute the dice loss between `pred` and the one-hot encoded `target`.

        Defined as `forward` (rather than overriding `__call__`) so that `loss(pred, target)` goes
        through the standard `nn.Module.__call__` machinery, including registered hooks.
        """
        if self.softmax:
            pred = torch.softmax(pred, dim=1)
        target = self._one_hot_encoder(target)
        return self.dice_loss(pred, target)

Loss for computing dice over one-hot labels.

Expects predictions with num_classes channels (the number of classes for semantic segmentation) and a single-channel target label map holding the class ids, which is one-hot encoded internally.

Arguments:
  • num_classes: The number of classes for semantic segmentation (including background class).
  • softmax: Whether to use softmax over the predictions.
CustomDiceLoss(num_classes: int, softmax: bool = True)
def __init__(self, num_classes: int, softmax: bool = True) -> None:
    """Initialize the loss.

    Args:
        num_classes: The number of classes for semantic segmentation (including background class).
        softmax: Whether to apply softmax over the channel axis of the predictions.
    """
    super().__init__()
    self.num_classes = num_classes
    self.dice_loss = DiceLoss()
    self.softmax = softmax

Initialize the loss with the number of classes and whether to apply softmax to the predictions.

num_classes
dice_loss
softmax
Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
forward
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class SemanticSamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
 43class SemanticSamTrainer(DefaultTrainer):
 44    """Trainer class for training the Segment Anything model for semantic segmentation.
 45
 46    This class is derived from `torch_em.trainer.DefaultTrainer`.
 47    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 48    for details on its usage and implementation.
 49
 50    Args:
 51        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 52            The class `micro_sam.training.util.ConvertToSemanticSamInputs` can be used here.
 53        num_classes: The number of classes for semantic segmentation (including the background class).
 54        dice_weight: The weighing for the dice loss in the combined dice-cross entropy loss function.
 55        kwargs: The keyword arguments of the DefaultTrainer super class.
 56    """
 57
 58    def __init__(self, convert_inputs, num_classes: int, dice_weight: Optional[float] = None, **kwargs):
 59        assert num_classes > 1
 60
 61        if "loss" not in kwargs:
 62            kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)
 63
 64        if "metric" not in kwargs:
 65            kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)
 66
 67        super().__init__(**kwargs)
 68
 69        self.convert_inputs = convert_inputs
 70        self.num_classes = num_classes
 71        self.compute_ce_loss = nn.CrossEntropyLoss()
 72        self.dice_weight = dice_weight
 73
 74        if self.dice_weight is not None and (self.dice_weight < 0 or self.dice_weight > 1):
 75            raise ValueError("The weight factor should lie between 0 and 1.")
 76
 77        self._kwargs = kwargs
 78
 79    def _compute_loss(self, y, masks):
 80        """Compute the combined (weighted) dice loss and cross-entropy loss between the prediction and target.
 81        """
 82        target = y.to(self.device, non_blocking=True)
 83        # Compute dice loss for the predictions
 84        dice_loss = self.loss(masks, target)
 85
 86        # Compute cross entropy loss for the predictions
 87        ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long())
 88
 89        if self.dice_weight is None:
 90            net_loss = dice_loss + ce_loss
 91        else:
 92            net_loss = self.dice_weight * dice_loss + (1 - self.dice_weight) * ce_loss
 93
 94        return net_loss
 95
 96    def _get_model_outputs(self, batched_inputs):
 97        """Get the predictions from the model.
 98        """
 99        # Precompute the image embeddings if the model exposes it as functionality.
100        if hasattr(self.model, "image_embeddings_oft"):
101            image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
102            batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True)
103        else:  # Otherwise we assume that the embeddings are computed internally as part of the forward pass.
104            # We need to take care of sending things to the device here.
105            batched_inputs = [
106                {"image": inp["image"].to(self.device, non_blocking=True), "original_size": inp["original_size"]}
107                for inp in batched_inputs
108            ]
109            batched_outputs = self.model(batched_inputs, multimask_output=True)
110
111        masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
112        return masks
113
114    def _train_epoch_impl(self, progress, forward_context, backprop):
115        self.model.train()
116
117        t_per_iter = time.time()
118        for x, y in self.train_loader:
119            self.optimizer.zero_grad()
120
121            batched_inputs = self.convert_inputs(x, y)
122
123            with forward_context():
124                masks = self._get_model_outputs(batched_inputs)
125                net_loss = self._compute_loss(y, masks)
126
127            backprop(net_loss)
128
129            self._iteration += 1
130
131            if self.logger is not None:
132                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
133                self.logger.log_train(
134                    self._iteration, net_loss, lr, x, y, torch.softmax(masks, dim=1), log_gradients=False
135                )
136
137            if self._iteration >= self.max_iteration:
138                break
139            progress.update(1)
140
141        t_per_iter = (time.time() - t_per_iter)
142        return t_per_iter
143
144    def _validate_impl(self, forward_context):
145        self.model.eval()
146
147        metric_val, loss_val = 0.0, 0.0
148
149        with torch.no_grad():
150            for x, y in self.val_loader:
151                batched_inputs = self.convert_inputs(x, y)
152
153                with forward_context():
154                    masks = self._get_model_outputs(batched_inputs)
155                    net_loss = self._compute_loss(y, masks)
156
157                loss_val += net_loss.item()
158                metric_val += net_loss.item()
159
160        loss_val /= len(self.val_loader)
161        metric_val /= len(self.val_loader)
162        dice_metric = 1 - (metric_val / self.num_classes)
163        print()
164        print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")
165
166        if self.logger is not None:
167            self.logger.log_validation(
168                self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)
169            )
170
171        return metric_val

Trainer class for training the Segment Anything model for semantic segmentation.

This class is derived from torch_em.trainer.DefaultTrainer. Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py for details on its usage and implementation.

Arguments:
  • convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSemanticSamInputs can be used here.
  • num_classes: The number of classes for semantic segmentation (including the background class).
  • dice_weight: The weighting for the dice loss in the combined dice-cross entropy loss function.
  • kwargs: The keyword arguments of the DefaultTrainer super class.
SemanticSamTrainer( convert_inputs, num_classes: int, dice_weight: Optional[float] = None, **kwargs)
def __init__(self, convert_inputs, num_classes: int, dice_weight: Optional[float] = None, **kwargs):
    """Set up the semantic SAM trainer.

    Args:
        convert_inputs: Converts dataloader outputs to the input format expected by SAM.
        num_classes: The number of classes (including background); must be greater than 1.
        dice_weight: Weight of the dice term in [0, 1], or None to sum dice and cross-entropy.
        kwargs: Forwarded to `DefaultTrainer`; `loss` and `metric` default to `CustomDiceLoss(num_classes)`.

    Raises:
        ValueError: If `dice_weight` is given and not within [0, 1].
    """
    assert num_classes > 1, f"Expected at least two classes (including background), got {num_classes}."

    # Validate before the (potentially expensive) super-class initialization; the negated
    # range check also rejects NaN, which the two separate comparisons would let through.
    if dice_weight is not None and not (0 <= dice_weight <= 1):
        raise ValueError("The weight factor should lie between 0 and 1.")

    if "loss" not in kwargs:
        kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)

    if "metric" not in kwargs:
        kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)

    super().__init__(**kwargs)

    self.convert_inputs = convert_inputs
    self.num_classes = num_classes
    self.compute_ce_loss = nn.CrossEntropyLoss()
    self.dice_weight = dice_weight

    # NOTE(review): presumably read by torch_em's checkpoint serialization to re-create the trainer - confirm.
    self._kwargs = kwargs
convert_inputs
num_classes
compute_ce_loss
dice_weight
Inherited Members
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit
class SemanticMapsSamTrainer(SemanticSamTrainer):
class SemanticMapsSamTrainer(SemanticSamTrainer):
    """Variant of `SemanticSamTrainer` that trains with `self.loss` alone.

    Unlike the parent class, no cross-entropy term is added and `dice_weight` is not used
    when computing the loss.
    """

    def _compute_loss(self, y, masks):
        """Compute `self.loss` between the predictions `masks` and the target `y` (moved to the trainer device)."""
        target = y.to(self.device, non_blocking=True)

        # Use the (prediction, target) argument order of the parent class and of torch_em losses.
        # This matters for asymmetric losses such as the default CustomDiceLoss, which applies
        # softmax to its first argument and one-hot encodes its second.
        net_loss = self.loss(masks, target)

        return net_loss

Trainer class for training the Segment Anything model for semantic segmentation.

This class is derived from torch_em.trainer.DefaultTrainer. Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py for details on its usage and implementation.

Arguments:
  • convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSemanticSamInputs can be used here.
  • num_classes: The number of classes for semantic segmentation (including the background class).
  • dice_weight: The weighting for the dice loss in the combined dice-cross entropy loss function.
  • kwargs: The keyword arguments of the DefaultTrainer super class.
Inherited Members
SemanticSamTrainer
SemanticSamTrainer
convert_inputs
num_classes
compute_ce_loss
dice_weight
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit