micro_sam.training.semantic_sam_trainer
    import time
    from typing import Optional

    import torch
    import torch.nn as nn

    from torch_em.loss import DiceLoss
    from torch_em.trainer import DefaultTrainer


    class CustomDiceLoss(nn.Module):
        """Loss for computing dice over one-hot labels.

        Expects prediction and target with `num_classes` channels: the number of classes for semantic segmentation.

        Args:
            num_classes: The number of classes for semantic segmentation (including background class).
            softmax: Whether to use softmax over the predictions.
        """
        def __init__(self, num_classes: int, softmax: bool = True) -> None:
            super().__init__()
            self.num_classes = num_classes
            self.dice_loss = DiceLoss()
            self.softmax = softmax

        def _one_hot_encoder(self, input_tensor):
            tensor_list = []
            for i in range(self.num_classes):
                temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
                tensor_list.append(temp_prob)
            output_tensor = torch.cat(tensor_list, dim=1)
            return output_tensor.float()

        def __call__(self, pred, target):
            if self.softmax:
                pred = torch.softmax(pred, dim=1)
            target = self._one_hot_encoder(target)
            loss = self.dice_loss(pred, target)
            return loss


    class SemanticSamTrainer(DefaultTrainer):
        """Trainer class for training the Segment Anything model for semantic segmentation.

        This class is derived from `torch_em.trainer.DefaultTrainer`.
        Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
        for details on its usage and implementation.

        Args:
            convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
                The class `micro_sam.training.util.ConvertToSemanticSamInputs` can be used here.
            num_classes: The number of classes for semantic segmentation (including the background class).
            dice_weight: The weighing for the dice loss in the combined dice-cross entropy loss function.
            kwargs: The keyword arguments of the DefaultTrainer super class.
        """

        def __init__(self, convert_inputs, num_classes: int, dice_weight: Optional[float] = None, **kwargs):
            assert num_classes > 1

            if "loss" not in kwargs:
                kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)

            if "metric" not in kwargs:
                kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)

            super().__init__(**kwargs)

            self.convert_inputs = convert_inputs
            self.num_classes = num_classes
            self.compute_ce_loss = nn.CrossEntropyLoss()
            self.dice_weight = dice_weight

            if self.dice_weight is not None and (self.dice_weight < 0 or self.dice_weight > 1):
                raise ValueError("The weight factor should lie between 0 and 1.")

            self._kwargs = kwargs

        def _compute_loss(self, y, masks):
            """Compute the combined (weighted) dice loss and cross-entropy loss between the prediction and target.
            """
            target = y.to(self.device, non_blocking=True)
            # Compute dice loss for the predictions
            dice_loss = self.loss(masks, target)

            # Compute cross entropy loss for the predictions
            ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long())

            if self.dice_weight is None:
                net_loss = dice_loss + ce_loss
            else:
                net_loss = self.dice_weight * dice_loss + (1 - self.dice_weight) * ce_loss

            return net_loss

        def _get_model_outputs(self, batched_inputs):
            """Get the predictions from the model.
            """
            # Precompute the image embeddings if the model exposes it as functionality.
            if hasattr(self.model, "image_embeddings_oft"):
                image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
                batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True)
            else:  # Otherwise we assume that the embeddings are computed internally as part of the forward pass.
                # We need to take care of sending things to the device here.
                batched_inputs = [
                    {"image": inp["image"].to(self.device, non_blocking=True), "original_size": inp["original_size"]}
                    for inp in batched_inputs
                ]
                batched_outputs = self.model(batched_inputs, multimask_output=True)

            masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
            return masks

        def _train_epoch_impl(self, progress, forward_context, backprop):
            self.model.train()

            t_per_iter = time.time()
            for x, y in self.train_loader:
                self.optimizer.zero_grad()

                batched_inputs = self.convert_inputs(x, y)

                with forward_context():
                    masks = self._get_model_outputs(batched_inputs)
                    net_loss = self._compute_loss(y, masks)

                backprop(net_loss)

                self._iteration += 1

                if self.logger is not None:
                    lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                    self.logger.log_train(
                        self._iteration, net_loss, lr, x, y, torch.softmax(masks, dim=1), log_gradients=False
                    )

                if self._iteration >= self.max_iteration:
                    break
                progress.update(1)

            t_per_iter = (time.time() - t_per_iter)
            return t_per_iter

        def _validate_impl(self, forward_context):
            self.model.eval()

            metric_val, loss_val = 0.0, 0.0

            with torch.no_grad():
                for x, y in self.val_loader:
                    batched_inputs = self.convert_inputs(x, y)

                    with forward_context():
                        masks = self._get_model_outputs(batched_inputs)
                        net_loss = self._compute_loss(y, masks)

                    loss_val += net_loss.item()
                    metric_val += net_loss.item()

            loss_val /= len(self.val_loader)
            metric_val /= len(self.val_loader)
            dice_metric = 1 - (metric_val / self.num_classes)
            print()
            print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")

            if self.logger is not None:
                self.logger.log_validation(
                    self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)
                )

            return metric_val


    class SemanticMapsSamTrainer(SemanticSamTrainer):
        def _compute_loss(self, y, masks):
            target = y.to(self.device, non_blocking=True)

            # Compute loss for the predictions
            net_loss = self.loss(target, masks)

            return net_loss
class CustomDiceLoss(nn.Module):
Loss for computing the dice score over one-hot encoded labels.

Expects prediction and target with num_classes channels, where num_classes is the number of classes for semantic segmentation.

Arguments:
- num_classes: The number of classes for semantic segmentation (including the background class).
- softmax: Whether to apply softmax to the predictions.
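As a quick usage illustration (not part of the module source), the loss can be applied to raw class logits and an integer label map. The tensor shapes below are arbitrary assumptions chosen for the sketch:

    import torch
    from micro_sam.training.semantic_sam_trainer import CustomDiceLoss

    num_classes = 4  # assumed number of semantic classes, including background

    # Dummy logits for a batch of 2 images with 64 x 64 pixels.
    pred = torch.randn(2, num_classes, 64, 64)
    # Integer label map with a single channel and values in [0, num_classes).
    target = torch.randint(low=0, high=num_classes, size=(2, 1, 64, 64))

    loss_fn = CustomDiceLoss(num_classes=num_classes)
    # Applies softmax over the channel dimension, one-hot encodes the target and computes the dice loss.
    loss = loss_fn(pred, target)
    print(loss.item())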
class SemanticSamTrainer(DefaultTrainer):
Trainer class for training the Segment Anything model for semantic segmentation.

This class is derived from torch_em.trainer.DefaultTrainer. Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py for details on its usage and implementation.

Arguments:
- convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSemanticSamInputs can be used here.
- num_classes: The number of classes for semantic segmentation (including the background class).
- dice_weight: The weighting for the dice loss in the combined dice-cross-entropy loss function.
- kwargs: The keyword arguments of the DefaultTrainer super class.
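A hedged construction sketch for the trainer: the SAM-based semantic segmentation model, the optimizer, and the train/validation loaders are assumed to have been set up elsewhere, and the exact DefaultTrainer keyword arguments (as well as the ConvertToSemanticSamInputs constructor) should be checked against the torch_em and micro_sam documentation.

    from micro_sam.training.util import ConvertToSemanticSamInputs
    from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer

    # `model`, `optimizer`, `train_loader` and `val_loader` are assumed to exist already.
    trainer = SemanticSamTrainer(
        name="semantic-sam",                        # checkpoint name (DefaultTrainer keyword)
        convert_inputs=ConvertToSemanticSamInputs(),
        num_classes=4,                              # including the background class
        dice_weight=0.5,                            # equal weighting of dice and cross-entropy
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device="cuda",
    )
    trainer.fit(iterations=int(1e4))

If no loss or metric is passed via the keyword arguments, the trainer falls back to CustomDiceLoss with the given number of classes.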
class SemanticMapsSamTrainer(SemanticSamTrainer):

Variant of SemanticSamTrainer that overrides _compute_loss to apply the configured loss directly between the target and the predicted masks, without the additional cross-entropy term.
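Because this subclass applies self.loss directly to the target and the predicted masks, a custom loss can be supplied via the loss keyword. A speculative sketch, assuming continuous-valued target maps and reusing the model, optimizer, and loaders assumed above:

    import torch.nn as nn

    from micro_sam.training.util import ConvertToSemanticSamInputs
    from micro_sam.training.semantic_sam_trainer import SemanticMapsSamTrainer

    trainer = SemanticMapsSamTrainer(
        name="semantic-maps-sam",
        convert_inputs=ConvertToSemanticSamInputs(),
        num_classes=2,                # must be > 1 due to the assert in the parent class
        loss=nn.MSELoss(),            # applied directly as loss(target, masks)
        metric=nn.MSELoss(),
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device="cuda",
    )
    trainer.fit(iterations=int(1e4))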