micro_sam.training.semantic_sam_trainer
import time
from typing import Optional

import torch
import torch.nn as nn

from torch_em.loss import DiceLoss
from torch_em.trainer import DefaultTrainer


class CustomDiceLoss(nn.Module):
    """Loss for computing dice over one-hot labels.

    Expects the prediction with `num_classes` channels and the target as a label image,
    which is one-hot encoded to `num_classes` channels internally.

    Args:
        num_classes: The number of classes for semantic segmentation (including the background class).
        softmax: Whether to apply softmax over the predictions.
    """
    def __init__(self, num_classes: int, softmax: bool = True) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.dice_loss = DiceLoss()
        self.softmax = softmax

    def _one_hot_encoder(self, input_tensor):
        # Stack per-class binary masks along the channel dimension to one-hot encode the labels.
        tensor_list = []
        for i in range(self.num_classes):
            temp_prob = input_tensor == i
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def __call__(self, pred, target):
        if self.softmax:
            pred = torch.softmax(pred, dim=1)
        target = self._one_hot_encoder(target)
        loss = self.dice_loss(pred, target)
        return loss


class SemanticSamTrainer(DefaultTrainer):
    """Trainer class for training the Segment Anything model for semantic segmentation.

    This class is derived from `torch_em.trainer.DefaultTrainer`.
    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
    for details on its usage and implementation.

    Args:
        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
            The class `micro_sam.training.util.ConvertToSemanticSamInputs` can be used here.
        num_classes: The number of classes for semantic segmentation (including the background class).
        dice_weight: The weighting for the dice loss in the combined dice-cross entropy loss function.
        kwargs: The keyword arguments of the DefaultTrainer super class.
    """
    def __init__(
        self,
        convert_inputs,
        num_classes: int,
        dice_weight: Optional[float] = None,
        **kwargs
    ):
        assert num_classes > 1

        if "loss" not in kwargs:
            kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)

        if "metric" not in kwargs:
            kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)

        super().__init__(**kwargs)

        self.convert_inputs = convert_inputs
        self.num_classes = num_classes
        self.compute_ce_loss = nn.CrossEntropyLoss()
        self.dice_weight = dice_weight

        if self.dice_weight is not None:
            assert 0 < self.dice_weight < 1, "The weight factor should lie between 0 and 1."

        self._kwargs = kwargs

    def _compute_loss(self, y, masks):
        """Compute the combined (weighted) dice loss and cross-entropy loss between the prediction and target."""
        target = y.to(self.device, non_blocking=True)

        # Compute the dice loss for the predictions.
        dice_loss = self.loss(masks, target)

        # Compute the cross entropy loss for the predictions.
        ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long())

        if self.dice_weight is None:
            net_loss = dice_loss + ce_loss
        else:
            net_loss = self.dice_weight * dice_loss + (1 - self.dice_weight) * ce_loss

        return net_loss

    def _get_model_outputs(self, batched_inputs):
        """Get the predictions from the model."""
        # Precompute the image embeddings if the model exposes this functionality.
        if hasattr(self.model, "image_embeddings_oft"):
            image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
            batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True)
        else:  # Otherwise we assume that the embeddings are computed internally as part of the forward pass.
            # We need to take care of sending things to the device here.
            batched_inputs = [
                {"image": inp["image"].to(self.device, non_blocking=True), "original_size": inp["original_size"]}
                for inp in batched_inputs
            ]
            batched_outputs = self.model(batched_inputs, multimask_output=True)

        masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
        return masks

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        t_per_iter = time.time()
        for x, y in self.train_loader:
            self.optimizer.zero_grad()

            batched_inputs = self.convert_inputs(x, y)

            with forward_context():
                masks = self._get_model_outputs(batched_inputs)
                net_loss = self._compute_loss(y, masks)

            backprop(net_loss)

            self._iteration += 1

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                self.logger.log_train(
                    self._iteration, net_loss, lr, x, y, torch.softmax(masks, dim=1), log_gradients=False
                )

            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = time.time() - t_per_iter
        return t_per_iter

    def _validate_impl(self, forward_context):
        self.model.eval()

        metric_val, loss_val = 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                batched_inputs = self.convert_inputs(x, y)

                with forward_context():
                    masks = self._get_model_outputs(batched_inputs)
                    net_loss = self._compute_loss(y, masks)

                loss_val += net_loss.item()
                metric_val += net_loss.item()

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        # Normalize by the number of classes and convert the loss into a dice score.
        dice_metric = 1 - (metric_val / self.num_classes)
        print()
        print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)
            )

        return metric_val


class SemanticMapsSamTrainer(SemanticSamTrainer):
    def _compute_loss(self, y, masks):
        target = y.to(self.device, non_blocking=True)

        # Compute the loss for the predictions.
        net_loss = self.loss(target, masks)

        return net_loss
class CustomDiceLoss(nn.Module):
Loss for computing dice over one-hot labels.

Expects the prediction with num_classes channels and the target as a label image, which is one-hot encoded to num_classes channels internally.

Arguments:
- num_classes: The number of classes for semantic segmentation (including the background class).
- softmax: Whether to apply softmax over the predictions.
CustomDiceLoss(num_classes: int, softmax: bool = True)
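A minimal usage sketch for the loss; the tensor shapes are illustrative assumptions, not requirements from the source:

import torch

from micro_sam.training.semantic_sam_trainer import CustomDiceLoss

loss_fn = CustomDiceLoss(num_classes=3)

# Logits with one channel per class; integer class labels with a singleton channel dimension.
pred = torch.randn(2, 3, 256, 256)
target = torch.randint(0, 3, (2, 1, 256, 256))

# Softmax is applied to pred, target is one-hot encoded, then dice is computed.
loss = loss_fn(pred, target)
print(loss.item())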
class SemanticSamTrainer(DefaultTrainer):
Trainer class for training the Segment Anything model for semantic segmentation.

This class is derived from torch_em.trainer.DefaultTrainer.
Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
for details on its usage and implementation.

Arguments:
- convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSemanticSamInputs can be used here.
- num_classes: The number of classes for semantic segmentation (including the background class).
- dice_weight: The weighting for the dice loss in the combined dice-cross entropy loss function.
- kwargs: The keyword arguments of the DefaultTrainer super class.
SemanticSamTrainer(convert_inputs, num_classes: int, dice_weight: Optional[float] = None, **kwargs)
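The body of _get_model_outputs (see the module source above) spells out the contract for convert_inputs: it maps a batch (x, y) to a list of per-image dicts with at least the keys "image" and "original_size". A hypothetical minimal stand-in, purely for illustration; the real micro_sam.training.util.ConvertToSemanticSamInputs may do additional work:

class MinimalSemanticInputs:
    # Hypothetical stand-in for micro_sam.training.util.ConvertToSemanticSamInputs.
    # Produces the list-of-dicts format consumed by SemanticSamTrainer._get_model_outputs.
    def __call__(self, x, y):
        # One dict per image; "original_size" is the spatial shape of the input.
        return [{"image": image, "original_size": tuple(image.shape[-2:])} for image in x]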
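A construction sketch, assuming dataloaders and a SAM model prepared for semantic segmentation are already at hand. Here model, train_loader and val_loader are placeholders, and all keyword arguments besides convert_inputs, num_classes and dice_weight are forwarded to the DefaultTrainer super class:

import torch

from micro_sam.training.util import ConvertToSemanticSamInputs
from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer

trainer = SemanticSamTrainer(
    convert_inputs=ConvertToSemanticSamInputs(),
    num_classes=3,    # background plus two foreground classes
    dice_weight=0.5,  # equal weighting of dice and cross-entropy loss
    # Keyword arguments of the DefaultTrainer super class:
    name="semantic-sam",
    model=model,                # placeholder: a SAM model for semantic segmentation
    train_loader=train_loader,  # placeholder dataloaders
    val_loader=val_loader,
    optimizer=torch.optim.AdamW(model.parameters(), lr=1e-5),
    device=torch.device("cuda"),
)
trainer.fit(iterations=int(1e4))  # fit is inherited from DefaultTrainer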
class SemanticMapsSamTrainer(SemanticSamTrainer):

Trainer variant that overrides _compute_loss to apply the configured loss directly to the target and the predicted maps, without the combined dice-cross entropy loss:

    def _compute_loss(self, y, masks):
        target = y.to(self.device, non_blocking=True)

        # Compute the loss for the predictions.
        net_loss = self.loss(target, masks)

        return net_loss
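Since SemanticMapsSamTrainer applies the configured loss directly, a loss other than the default CustomDiceLoss (which one-hot encodes integer labels) can be passed through the loss and metric keyword arguments of the super class. A hypothetical sketch with a regression loss for continuous target maps; model and the dataloaders are again placeholders:

import torch
import torch.nn as nn

from micro_sam.training.util import ConvertToSemanticSamInputs
from micro_sam.training.semantic_sam_trainer import SemanticMapsSamTrainer

trainer = SemanticMapsSamTrainer(
    convert_inputs=ConvertToSemanticSamInputs(),
    num_classes=2,
    loss=nn.MSELoss(),    # hypothetical choice for continuous semantic maps
    metric=nn.MSELoss(),
    name="semantic-maps-sam",
    model=model,                # placeholder: a SAM model for semantic segmentation
    train_loader=train_loader,  # placeholder dataloaders
    val_loader=val_loader,
    optimizer=torch.optim.AdamW(model.parameters(), lr=1e-5),
    device=torch.device("cuda"),
)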