micro_sam.training.sam_trainer
import os
import time
import random
import warnings
from typing import Optional

import numpy as np

import torch
from torchvision.utils import make_grid

import torch_em
from torch_em.trainer.logger_base import TorchEmLogger

from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator


class SamTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer class for training the Segment Anything model.

    This class is derived from `torch_em.trainer.DefaultTrainer`.
    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
    for details on its usage and implementation.

    Args:
        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
            In each sub-iteration new point prompts are sampled where the model was wrong.
        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training
        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`)
        mask_loss: The loss to compare the predicted masks and the targets.
        kwargs: The keyword arguments of the DefaultTrainer super class.
    """

    def __init__(
        self,
        convert_inputs,
        n_sub_iteration: int,
        n_objects_per_batch: Optional[int] = None,
        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
        mask_prob: float = 0.5,
        mask_loss: Optional[torch.nn.Module] = None,
        **kwargs
    ):
        if mask_loss is None:
            # We have to use the Dice Loss with reduce channel set to None.
            # Hence we hard-code it here to avoid issues by passing wrong options for the loss.
            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
        else:
            self.mask_loss = mask_loss

        # The mask loss is handed to the super class both as loss and as metric.
        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
        self.convert_inputs = convert_inputs
        self.mse_loss = mse_loss
        self.n_objects_per_batch = n_objects_per_batch
        self.n_sub_iteration = n_sub_iteration
        self.prompt_generator = prompt_generator
        self.mask_prob = mask_prob
        self._kwargs = kwargs

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Choose the type of prompts we sample for training, and then we call
        'convert_inputs' with the correct prompting from here.

        Even iterations use a single positive point with multi-masking,
        odd iterations use a single box without multi-masking.

        Returns:
            The tuple (n_pos, n_neg, get_boxes, multimask_output).
        """
        if current_iteration % 2 == 0:  # sample only a single point per object
            n_pos, n_neg = 1, 0
            get_boxes = False
            multimask_output = True

        else:  # sample only a single box per object
            n_pos, n_neg = 0, 0
            get_boxes = True
            multimask_output = False

        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Choose the type of prompts we sample for validation, and then we call
        'convert_inputs' with the correct prompting from here.

        The setting cycles with period four over the validation iterations:
        single point, single box, random number of points, box plus random number of points.

        Returns:
            The tuple (n_pos, n_neg, get_boxes, multimask_output).
        """
        if current_iteration % 4 == 0:  # sample only a single point per object
            n_pos, n_neg = 1, 0
            get_boxes = False
            multimask_output = True

        elif current_iteration % 4 == 1:  # sample only a single box per object
            n_pos, n_neg = 0, 0
            get_boxes = True
            multimask_output = False

        elif current_iteration % 4 == 2:  # sample a random no. of points
            pos_range, neg_range = 4, 4

            n_pos = np.random.randint(1, pos_range + 1)
            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
                n_neg = np.random.randint(1, neg_range + 1)
            else:
                n_neg = np.random.randint(0, neg_range + 1)
            get_boxes = False
            multimask_output = False

        else:  # sample boxes AND random no. of points
            # here we can have (1, 0) because we also have box
            pos_range, neg_range = 4, 4

            n_pos = np.random.randint(1, pos_range + 1)
            n_neg = np.random.randint(0, neg_range + 1)
            get_boxes = True
            multimask_output = False

        return n_pos, n_neg, get_boxes, multimask_output

    def _compute_iou(self, pred, true, eps=1e-7):
        """Compute the IoU score between the prediction and target.

        The prediction is binarized with a threshold of 0.5. Overlap and union are summed
        over the axes (1, 2, 3), so one IoU value per entry of the first axis is returned.
        `eps` is added to the union to avoid a division by zero.
        """
        pred_mask = pred > 0.5  # binarizing the output predictions
        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
        iou = overlap / (union + eps)
        return iou

    def _compute_loss(self, batched_outputs, y_one_hot):
        """Compute the loss for one iteration. The loss is made up of two components:
        - The mask loss: dice score between the predicted masks and targets.
        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.

        Both components are summed (not averaged) over the samples in the batch.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss), where loss is the sum of the other two.
        """
        mask_loss, iou_regression_loss = 0.0, 0.0

        # Loop over the batch.
        for batch_output, targets in zip(batched_outputs, y_one_hot):

            predicted_objects = torch.sigmoid(batch_output["masks"])
            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
            # We swap the axes that go into the dice loss so that the object axis
            # corresponds to the channel axes. This ensures that the dice is computed
            # independently per channel. We do not reduce the channel axis in the dice,
            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
            dice_scores = torch.stack([
                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
                for i in range(predicted_objects.shape[1])
            ])
            dice_scores, _ = torch.min(dice_scores, dim=0)

            # Compute the actual IOU between the predicted and true objects.
            # The outer loop is for the 1 or 3 predicted masks per true object.
            with torch.no_grad():
                true_iou = torch.stack([
                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
                ])
            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
            # the object axis is back in the first dimension.
            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])

            mask_loss = mask_loss + torch.mean(dice_scores)
            iou_regression_loss = iou_regression_loss + iou_score

        loss = mask_loss + iou_regression_loss

        return loss, mask_loss, iou_regression_loss

    #
    # Functionality for iterative prompting loss
    #

    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
        """Select, per object, the predicted mask and low-res logits with the highest predicted IOU.

        Returns:
            The tuple (masks, logits). The masks are binarized and both outputs
            have the shape (batch_size, n_objects, 1, height, width).
        """
        # Batched mask and logit (low-res mask) predictions.
        masks = torch.stack([m["masks"] for m in batched_outputs])
        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])

        # Determine the best IOU across the multi-object prediction axis
        # and turn this into a mask we can use for indexing.
        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
        # for details on the indexing logic.
        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()

        # Index the mask and logits with the best iou indices.
        # Note that we squash the first two axes (batch x objects) into one when indexing.
        # That's why we need to reshape back into (batch x objects) using a view.
        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
        batch_size, n_objects = masks.shape[:2]
        h, w = masks.shape[-2:]
        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)

        h, w = logits.shape[-2:]
        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)

        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
        masks = (masks > 0.0).float()
        return masks, logits

    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
        """Compute the loss for several (sub-)iterations of iterative prompting.
        In each iterations the prompts are updated based on the previous predictions.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, mean_model_iou),
            each averaged over the number of sub-iterations.
        """
        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0

        for i in range(0, num_subiter):
            # We do multimasking only in the first sub-iteration as we then pass single prompt
            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
            batched_outputs = self.model(
                batched_inputs=batched_inputs,
                image_embeddings=image_embeddings,
                multimask_output=multimask_output if i == 0 else False,
            )

            # Compute loss for this sub-iteration.
            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)

            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
            with torch.no_grad():
                net_mean_model_iou = torch.mean(batched_iou_predictions)

            loss = loss + net_loss
            mask_loss = mask_loss + net_mask_loss
            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
            mean_model_iou = mean_model_iou + net_mean_model_iou

            if i < (num_subiter - 1):  # We need not update the prompts for the last iteration.
                # Determine the next prompts based on current predictions.
                with torch.no_grad():
                    # Get the mask and logit predictions corresponding to the predicted object
                    # (per actual object) with the best IOU.
                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)

        loss = loss / num_subiter
        mask_loss = mask_loss / num_subiter
        iou_regression_loss = iou_regression_loss / num_subiter
        mean_model_iou = mean_model_iou / num_subiter

        return loss, mask_loss, iou_regression_loss, mean_model_iou

    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
        """Extend the prompts of each sample with new points derived from the current predictions.

        The dictionaries in `batched_inputs` are modified in place: the newly sampled points are
        appended to "point_coords" / "point_labels" and, with probability `self.mask_prob`,
        the low-res logits are set as "mask_inputs" (otherwise previous mask inputs are removed).

        Returns:
            The (in place updated) batched_inputs.
        """
        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
            # here, we get each object in the pairs and do the point choices per-object
            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)

            # convert the point coordinates to the expected resolution for iterative prompting
            # NOTE:
            #   - "only" need to transform the point prompts from the iterative prompting
            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])

            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
                if "point_coords" in _inp.keys() else net_coords
            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
                if "point_labels" in _inp.keys() else net_labels

            _inp["point_coords"] = updated_point_coords
            _inp["point_labels"] = updated_point_labels

            if self.mask_prob > 0:
                # using mask inputs for iterative prompting while training, with a probability
                use_mask_inputs = (random.random() < self.mask_prob)
                if use_mask_inputs:
                    _inp["mask_inputs"] = logits
                else:  # remove previously existing mask inputs to avoid using them in next sub-iteration
                    _inp.pop("mask_inputs", None)

        return batched_inputs

    #
    # Training Loop
    #

    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
        """Compute one hot target (one mask per channel) for the sampled ids
        and restrict the number of sampled objects to the minimal number in the batch.

        Returns:
            The tuple (batched_inputs, y_one_hot), where the prompts in batched_inputs
            are restricted to the same number of objects as the one hot targets.
        """
        assert len(y) == len(sampled_ids)

        # Get the minimal number of objects in this batch.
        # The number of objects in a patch might be < n_objects_per_batch.
        # This is why we need to restrict it here to ensure the same
        # number of objects across the batch.
        n_objects = min(len(ids) for ids in sampled_ids)

        y = y.to(self.device, non_blocking=True)
        # Compute the one hot targets for the seg-id.
        y_one_hot = torch.stack([
            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
            for target, ids in zip(y, sampled_ids)
        ]).float()

        # Also restrict the prompts to the number of objects.
        batched_inputs = [
            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
            for inp in batched_inputs
        ]
        return batched_inputs, y_one_hot

    def _interactive_train_iteration(self, x, y):
        """Run one training iteration with iterative prompting.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, model_iou, y_one_hot).
        """
        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)

        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
            batched_inputs=batched_inputs,
            y_one_hot=y_one_hot,
            num_subiter=self.n_sub_iteration,
            multimask_output=multimask_output
        )
        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot

    def _check_input_normalization(self, x, input_check_done):
        """Warn if the input data looks normalized; the check runs only while `input_check_done` is False.

        Returns:
            True, so that the caller can skip the check for the remaining batches.
        """
        # The expected data range of the SAM model is 8bit (0-255).
        # It can easily happen that data is normalized beforehand in training.
        # For some reasons we don't fully understand this still works, but it
        # should still be avoided and is very detrimental in some settings
        # (e.g. when freezing the image encoder)
        # We check once per epoch if the data seems to be normalized already and
        # raise a warning if this is the case.
        if not input_check_done:
            data_min, data_max = x.min(), x.max()
            if (data_min < 0) or (data_max < 1):
                # Adjacent string literals are concatenated verbatim,
                # so each sentence needs its own trailing space.
                warnings.warn(
                    "It looks like you are normalizing the training data. "
                    "The SAM model takes care of normalization, so it is better to not do this. "
                    "We recommend to remove data normalization and input data in the range [0, 255]."
                )
            input_check_done = True

        return input_check_done

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Train for one epoch (or until `self.max_iteration` is reached).

        Returns:
            The average time per iteration.
        """
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, y)

            backprop(loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                # The samples are only needed in the iterations where images are logged.
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(
                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _interactive_val_iteration(self, x, y, val_iteration):
        """Run one validation iteration (a single forward pass, without iterative prompting).

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric).
        """
        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)

        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

        batched_outputs = self.model(
            batched_inputs=batched_inputs,
            image_embeddings=image_embeddings,
            multimask_output=multimask_output,
        )

        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
        # We use the dice loss over the masks as metric.
        metric = mask_loss
        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))

        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric

    def _validate_impl(self, forward_context):
        """Run the validation over `self.val_loader` and log the averaged results.

        Returns:
            The validation metric (the mask loss) averaged over the validation loader.
        """
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
        mask_loss_val, iou_regression_loss_val = 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)

                loss_val += loss.item()
                metric_val += metric.item()
                model_iou_val += model_iou.item()
                # Accumulate the loss components too, so that the logged values are averaged over
                # all validation batches (and prompt settings) instead of reflecting only the last batch.
                mask_loss_val += mask_loss.item()
                iou_regression_loss_val += iou_regression_loss.item()
                val_iteration += 1

        n_batches = len(self.val_loader)
        loss_val /= n_batches
        metric_val /= n_batches
        model_iou_val /= n_batches
        mask_loss_val /= n_batches
        iou_regression_loss_val /= n_batches
        print()
        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")

        if self.logger is not None:
            # The images (x, y, sampled_binary_y) come from the last validation batch.
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, y,
                sampled_binary_y, mask_loss_val, iou_regression_loss_val, model_iou_val
            )

        return metric_val


class SamLogger(TorchEmLogger):
    """@private"""
    # Tensorboard logger for the SamTrainer: writes the loss components, the IOU predicted
    # by the model and (at `log_image_interval`) the inputs, targets and sampled objects.

    def __init__(self, trainer, save_root, **unused_kwargs):
        # `import torch` alone does not load the `torch.utils.tensorboard` submodule,
        # so we import the writer explicitly instead of relying on another module having done so.
        from torch.utils.tensorboard import SummaryWriter

        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        # Only the first element of the batch is logged.
        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        # Images are only logged every `log_image_interval` steps.
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
class
SamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
19class SamTrainer(torch_em.trainer.DefaultTrainer): 20 """Trainer class for training the Segment Anything model. 21 22 This class is derived from `torch_em.trainer.DefaultTrainer`. 23 Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py 24 for details on its usage and implementation. 25 26 Args: 27 convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. 28 The class `micro_sam.training.util.ConvertToSamInputs` can be used here. 29 n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. 30 In each sub-iteration new point prompts are sampled where the model was wrong. 31 n_objects_per_batch: If not given, we compute the loss for all objects in a sample. 32 Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled. 33 mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. 34 prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training 35 mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`) 36 mask_loss: The loss to compare the predicted masks and the targets. 37 kwargs: The keyword arguments of the DefaultTrainer super class. 38 """ 39 40 def __init__( 41 self, 42 convert_inputs, 43 n_sub_iteration: int, 44 n_objects_per_batch: Optional[int] = None, 45 mse_loss: torch.nn.Module = torch.nn.MSELoss(), 46 prompt_generator: PromptGeneratorBase = IterativePromptGenerator(), 47 mask_prob: float = 0.5, 48 mask_loss: Optional[torch.nn.Module] = None, 49 **kwargs 50 ): 51 if mask_loss is None: 52 # We have to use the Dice Loss with reduce channel set to None. 53 # Hence we hard-code it here to avoid issues by passsing wrong options for the loss. 
54 self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None) 55 else: 56 self.mask_loss = mask_loss 57 58 super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs) 59 self.convert_inputs = convert_inputs 60 self.mse_loss = mse_loss 61 self.n_objects_per_batch = n_objects_per_batch 62 self.n_sub_iteration = n_sub_iteration 63 self.prompt_generator = prompt_generator 64 self.mask_prob = mask_prob 65 self._kwargs = kwargs 66 67 def _get_prompt_and_multimasking_choices(self, current_iteration): 68 """Choose the type of prompts we sample for training, and then we call 69 'convert_inputs' with the correct prompting from here. 70 """ 71 if current_iteration % 2 == 0: # sample only a single point per object 72 n_pos, n_neg = 1, 0 73 get_boxes = False 74 multimask_output = True 75 76 else: # sample only a single box per object 77 n_pos, n_neg = 0, 0 78 get_boxes = True 79 multimask_output = False 80 81 return n_pos, n_neg, get_boxes, multimask_output 82 83 def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): 84 """Choose the type of prompts we sample for validation, and then we call 85 'convert_inputs' with the correct prompting from here. 86 """ 87 if current_iteration % 4 == 0: # sample only a single point per object 88 n_pos, n_neg = 1, 0 89 get_boxes = False 90 multimask_output = True 91 92 elif current_iteration % 4 == 1: # sample only a single box per object 93 n_pos, n_neg = 0, 0 94 get_boxes = True 95 multimask_output = False 96 97 elif current_iteration % 4 == 2: # sample a random no. of points 98 pos_range, neg_range = 4, 4 99 100 n_pos = np.random.randint(1, pos_range + 1) 101 if n_pos == 1: # to avoid (1, 0) combination for redundancy but still have (n_pos, 0) 102 n_neg = np.random.randint(1, neg_range + 1) 103 else: 104 n_neg = np.random.randint(0, neg_range + 1) 105 get_boxes = False 106 multimask_output = False 107 108 else: # sample boxes AND random no. 
of points 109 # here we can have (1, 0) because we also have box 110 pos_range, neg_range = 4, 4 111 112 n_pos = np.random.randint(1, pos_range + 1) 113 n_neg = np.random.randint(0, neg_range + 1) 114 get_boxes = True 115 multimask_output = False 116 117 return n_pos, n_neg, get_boxes, multimask_output 118 119 def _compute_iou(self, pred, true, eps=1e-7): 120 """Compute the IoU score between the prediction and target. 121 """ 122 pred_mask = pred > 0.5 # binarizing the output predictions 123 overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3)) 124 union = pred_mask.logical_or(true).sum(dim=(1, 2, 3)) 125 iou = overlap / (union + eps) 126 return iou 127 128 def _compute_loss(self, batched_outputs, y_one_hot): 129 """Compute the loss for one iteration. The loss is made up of two components: 130 - The mask loss: dice score between the predicted masks and targets. 131 - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target. 132 """ 133 mask_loss, iou_regression_loss = 0.0, 0.0 134 135 # Loop over the batch. 136 for batch_output, targets in zip(batched_outputs, y_one_hot): 137 138 predicted_objects = torch.sigmoid(batch_output["masks"]) 139 # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop). 140 # We swap the axes that go into the dice loss so that the object axis 141 # corresponds to the channel axes. This ensures that the dice is computed 142 # independetly per channel. We do not reduce the channel axis in the dice, 143 # so that we can take the minimum (best score) of the 1/3 predicted masks per object. 144 dice_scores = torch.stack([ 145 self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1)) 146 for i in range(predicted_objects.shape[1]) 147 ]) 148 dice_scores, _ = torch.min(dice_scores, dim=0) 149 150 # Compute the actual IOU between the predicted and true objects. 151 # The outer loop is for the 1 or 3 predicted masks per true object. 
152 with torch.no_grad(): 153 true_iou = torch.stack([ 154 self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1]) 155 ]) 156 # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that 157 # the object axis is back in the first dimension. 158 iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"]) 159 160 mask_loss = mask_loss + torch.mean(dice_scores) 161 iou_regression_loss = iou_regression_loss + iou_score 162 163 loss = mask_loss + iou_regression_loss 164 165 return loss, mask_loss, iou_regression_loss 166 167 # 168 # Functionality for iterative prompting loss 169 # 170 171 def _get_best_masks(self, batched_outputs, batched_iou_predictions): 172 # Batched mask and logit (low-res mask) predictions. 173 masks = torch.stack([m["masks"] for m in batched_outputs]) 174 logits = torch.stack([m["low_res_masks"] for m in batched_outputs]) 175 176 # Determine the best IOU across the multi-object prediction axis 177 # and turn this into a mask we can use for indexing. 178 # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax 179 # for details on the indexing logic. 180 best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True) 181 best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool() 182 183 # Index the mask and logits with the best iou indices. 184 # Note that we squash the first two axes (batch x objects) into one when indexing. 185 # That's why we need to reshape bax into (batch x objects) using a view. 186 # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...) 187 batch_size, n_objects = masks.shape[:2] 188 h, w = masks.shape[-2:] 189 masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w) 190 191 h, w = logits.shape[-2:] 192 logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w) 193 194 # Binarize the mask. 
Note that the mask here also contains logits, so we use 0.0 195 # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid) 196 masks = (masks > 0.0).float() 197 return masks, logits 198 199 def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output): 200 """Compute the loss for several (sub-)iterations of iterative prompting. 201 In each iterations the prompts are updated based on the previous predictions. 202 """ 203 image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) 204 205 loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0 206 207 for i in range(0, num_subiter): 208 # We do multimasking only in the first sub-iteration as we then pass single prompt 209 # after the first sub-iteration, we don't do multimasking because we get multiple prompts. 210 batched_outputs = self.model( 211 batched_inputs=batched_inputs, 212 image_embeddings=image_embeddings, 213 multimask_output=multimask_output if i == 0 else False, 214 ) 215 216 # Compute loss for tis sub-iteration. 217 net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot) 218 219 # Compute the mean IOU predicted by the model. We keep track of this in the logger. 220 batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs]) 221 with torch.no_grad(): 222 net_mean_model_iou = torch.mean(batched_iou_predictions) 223 224 loss = loss + net_loss 225 mask_loss = mask_loss + net_mask_loss 226 iou_regression_loss = iou_regression_loss + net_iou_regression_loss 227 mean_model_iou = mean_model_iou + net_mean_model_iou 228 229 if i < (num_subiter - 1): # We need not update the prompts for the last iteration. 230 # Determine the next prompts based on current predictions. 231 with torch.no_grad(): 232 # Get the mask and logit predictions corresponding to the predicted object 233 # (per actual object) with the best IOU. 
234 masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions) 235 batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits) 236 237 loss = loss / num_subiter 238 mask_loss = mask_loss / num_subiter 239 iou_regression_loss = iou_regression_loss / num_subiter 240 mean_model_iou = mean_model_iou / num_subiter 241 242 return loss, mask_loss, iou_regression_loss, mean_model_iou 243 244 def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks): 245 # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs") 246 for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks): 247 # here, we get each object in the pairs and do the point choices per-object 248 net_coords, net_labels, _, _ = self.prompt_generator(x2, x1) 249 250 # convert the point coordinates to the expected resolution for iterative prompting 251 # NOTE: 252 # - "only" need to transform the point prompts from the iterative prompting 253 # - the `logits` are the low res masks (256, 256), hence do not need the transform 254 net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:]) 255 256 updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \ 257 if "point_coords" in _inp.keys() else net_coords 258 updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \ 259 if "point_labels" in _inp.keys() else net_labels 260 261 _inp["point_coords"] = updated_point_coords 262 _inp["point_labels"] = updated_point_labels 263 264 if self.mask_prob > 0: 265 # using mask inputs for iterative prompting while training, with a probability 266 use_mask_inputs = (random.random() < self.mask_prob) 267 if use_mask_inputs: 268 _inp["mask_inputs"] = logits 269 else: # remove previously existing mask inputs to avoid using them in next sub-iteration 270 _inp.pop("mask_inputs", None) 271 272 return batched_inputs 273 274 # 275 # Training Loop 276 # 277 278 
def _preprocess_batch(self, batched_inputs, y, sampled_ids): 279 """Compute one hot target (one mask per channel) for the sampled ids 280 and restrict the number of sampled objects to the minimal number in the batch. 281 """ 282 assert len(y) == len(sampled_ids) 283 284 # Get the minimal number of objects in this batch. 285 # The number of objects in a patch might be < n_objects_per_batch. 286 # This is why we need to restrict it here to ensure the same 287 # number of objects across the batch. 288 n_objects = min(len(ids) for ids in sampled_ids) 289 290 y = y.to(self.device, non_blocking=True) 291 # Compute the one hot targets for the seg-id. 292 y_one_hot = torch.stack([ 293 torch.stack([target == seg_id for seg_id in ids[:n_objects]]) 294 for target, ids in zip(y, sampled_ids) 295 ]).float() 296 297 # Also restrict the prompts to the number of objects. 298 batched_inputs = [ 299 {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()} 300 for inp in batched_inputs 301 ] 302 return batched_inputs, y_one_hot 303 304 def _interactive_train_iteration(self, x, y): 305 n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration) 306 307 batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch) 308 batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids) 309 310 loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss( 311 batched_inputs=batched_inputs, 312 y_one_hot=y_one_hot, 313 num_subiter=self.n_sub_iteration, 314 multimask_output=multimask_output 315 ) 316 return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot 317 318 def _check_input_normalization(self, x, input_check_done): 319 # The expected data range of the SAM model is 8bit (0-255). 320 # It can easily happen that data is normalized beforehand in training. 
321 # For some reasons we don't fully understand this still works, but it 322 # should still be avoided and is very detrimental in some settings 323 # (e.g. when freezing the image encoder) 324 # We check once per epoch if the data seems to be normalized already and 325 # raise a warning if this is the case. 326 if not input_check_done: 327 data_min, data_max = x.min(), x.max() 328 if (data_min < 0) or (data_max < 1): 329 warnings.warn( 330 "It looks like you are normalizing the training data." 331 "The SAM model takes care of normalization, so it is better to not do this." 332 "We recommend to remove data normalization and input data in the range [0, 255]." 333 ) 334 input_check_done = True 335 336 return input_check_done 337 338 def _train_epoch_impl(self, progress, forward_context, backprop): 339 self.model.train() 340 341 input_check_done = False 342 343 n_iter = 0 344 t_per_iter = time.time() 345 for x, y in self.train_loader: 346 input_check_done = self._check_input_normalization(x, input_check_done) 347 348 self.optimizer.zero_grad() 349 350 with forward_context(): 351 (loss, mask_loss, iou_regression_loss, model_iou, 352 sampled_binary_y) = self._interactive_train_iteration(x, y) 353 354 backprop(loss) 355 356 if self.logger is not None: 357 lr = [pm["lr"] for pm in self.optimizer.param_groups][0] 358 samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None 359 self.logger.log_train( 360 self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou 361 ) 362 363 self._iteration += 1 364 n_iter += 1 365 if self._iteration >= self.max_iteration: 366 break 367 progress.update(1) 368 369 t_per_iter = (time.time() - t_per_iter) / n_iter 370 return t_per_iter 371 372 def _interactive_val_iteration(self, x, y, val_iteration): 373 n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration) 374 375 batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, 
get_boxes, self.n_objects_per_batch) 376 batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids) 377 378 image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) 379 380 batched_outputs = self.model( 381 batched_inputs=batched_inputs, 382 image_embeddings=image_embeddings, 383 multimask_output=multimask_output, 384 ) 385 386 loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot) 387 # We use the dice loss over the masks as metric. 388 metric = mask_loss 389 model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs])) 390 391 return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric 392 393 def _validate_impl(self, forward_context): 394 self.model.eval() 395 396 input_check_done = False 397 398 val_iteration = 0 399 metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0 400 401 with torch.no_grad(): 402 for x, y in self.val_loader: 403 input_check_done = self._check_input_normalization(x, input_check_done) 404 405 with forward_context(): 406 (loss, mask_loss, iou_regression_loss, model_iou, 407 sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration) 408 409 loss_val += loss.item() 410 metric_val += metric.item() 411 model_iou_val += model_iou.item() 412 val_iteration += 1 413 414 loss_val /= len(self.val_loader) 415 metric_val /= len(self.val_loader) 416 model_iou_val /= len(self.val_loader) 417 print() 418 print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}") 419 420 if self.logger is not None: 421 self.logger.log_validation( 422 self._iteration, metric_val, loss_val, x, y, 423 sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val 424 ) 425 426 return metric_val
Trainer class for training the Segment Anything model.
This class is derived from `torch_em.trainer.DefaultTrainer`.
Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
for details on its usage and implementation.
Arguments:
- convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
- n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. In each sub-iteration new point prompts are sampled where the model was wrong.
- n_objects_per_batch: If not given, we compute the loss for all objects in a sample. Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
- mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
- prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
- mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
- mask_loss: The loss to compare the predicted masks and the targets.
- kwargs: The keyword arguments of the DefaultTrainer super class.
SamTrainer( convert_inputs, n_sub_iteration: int, n_objects_per_batch: Optional[int] = None, mse_loss: torch.nn.modules.module.Module = MSELoss(), prompt_generator: micro_sam.prompt_generators.PromptGeneratorBase = <micro_sam.prompt_generators.IterativePromptGenerator object>, mask_prob: float = 0.5, mask_loss: Optional[torch.nn.modules.module.Module] = None, **kwargs)
def __init__(
    self,
    convert_inputs,
    n_sub_iteration: int,
    n_objects_per_batch: Optional[int] = None,
    mse_loss: torch.nn.Module = torch.nn.MSELoss(),
    prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
    mask_prob: float = 0.5,
    mask_loss: Optional[torch.nn.Module] = None,
    **kwargs
):
    """Initialize the trainer and store the settings for interactive training.

    Args:
        convert_inputs: The class that converts outputs of the dataloader
            to the expected input format of SAM.
        n_sub_iteration: The number of iteration steps for which the masks
            predicted for one object are updated.
        n_objects_per_batch: If given, the loss computation is limited to
            this number of randomly sampled objects per sample.
        mse_loss: The regression loss to compare the IoU predicted by the
            model with the true IoU.
        prompt_generator: The iterative prompt generator which takes care
            of the iterative prompting logic for training.
        mask_prob: The probability of using the mask inputs in the
            iterative prompting.
        mask_loss: The loss to compare the predicted masks and the targets.
            If None, a Dice loss with `reduce_channel=None` is used.
        kwargs: The keyword arguments of the DefaultTrainer super class.
    """
    # The Dice loss has to be used with `reduce_channel=None`, so the fallback
    # is hard-coded here instead of relying on correctly passed loss options.
    self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None) if mask_loss is None else mask_loss

    # The mask loss serves as both the loss and the metric of the base trainer.
    super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)

    # Remember the keyword arguments that were forwarded to the super class.
    self._kwargs = kwargs

    # Settings for prompt sampling and the iterative prompting scheme.
    self.convert_inputs = convert_inputs
    self.prompt_generator = prompt_generator
    self.n_sub_iteration = n_sub_iteration
    self.n_objects_per_batch = n_objects_per_batch
    self.mask_prob = mask_prob

    # Regression loss for the IoU predictions.
    self.mse_loss = mse_loss
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit