micro_sam.training.sam_trainer
import os
import time
import random
import warnings
from typing import Optional

import numpy as np

import torch
from torchvision.utils import make_grid

import torch_em
from torch_em.trainer.logger_base import TorchEmLogger

from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator


class SamTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer class for training the Segment Anything model.

    This class is derived from `torch_em.trainer.DefaultTrainer`.
    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
    for details on its usage and implementation.

    Args:
        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
            In each sub-iteration new point prompts are sampled where the model was wrong.
        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
        mask_loss: The loss to compare the predicted masks and the targets.
        kwargs: The keyword arguments of the DefaultTrainer super class.
    """

    def __init__(
        self,
        convert_inputs,
        n_sub_iteration: int,
        n_objects_per_batch: Optional[int] = None,
        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
        mask_prob: float = 0.5,
        mask_loss: Optional[torch.nn.Module] = None,
        **kwargs
    ):
        # NOTE: the default `mse_loss` and `prompt_generator` objects are created once,
        # when this function is defined, and are therefore shared by all trainers
        # that rely on the defaults.
        if mask_loss is None:
            # We have to use the Dice Loss with reduce channel set to None.
            # Hence we hard-code it here to avoid issues by passing wrong options for the loss.
            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
        else:
            self.mask_loss = mask_loss

        # The mask loss is registered both as loss and as metric in the super class.
        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
        self.convert_inputs = convert_inputs
        self.mse_loss = mse_loss
        self.n_objects_per_batch = n_objects_per_batch
        self.n_sub_iteration = n_sub_iteration
        self.prompt_generator = prompt_generator
        self.mask_prob = mask_prob
        self._kwargs = kwargs

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Choose the type of prompts we sample for training, and then we call
        'convert_inputs' with the correct prompting from here.

        Even iterations use a single positive point (with multi-masking),
        odd iterations use a single box (without multi-masking).

        Returns:
            The tuple (n_pos, n_neg, get_boxes, multimask_output).
        """
        if current_iteration % 2 == 0:  # sample only a single point per object
            n_pos, n_neg = 1, 0
            get_boxes = False
            multimask_output = True

        else:  # sample only a single box per object
            n_pos, n_neg = 0, 0
            get_boxes = True
            multimask_output = False

        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Choose the type of prompts we sample for validation, and then we call
        'convert_inputs' with the correct prompting from here.

        The prompt setting cycles with period four over: a single point, a single box,
        a random number of points, and a box together with a random number of points.

        Returns:
            The tuple (n_pos, n_neg, get_boxes, multimask_output).
        """
        if current_iteration % 4 == 0:  # sample only a single point per object
            n_pos, n_neg = 1, 0
            get_boxes = False
            multimask_output = True

        elif current_iteration % 4 == 1:  # sample only a single box per object
            n_pos, n_neg = 0, 0
            get_boxes = True
            multimask_output = False

        elif current_iteration % 4 == 2:  # sample a random no. of points
            pos_range, neg_range = 4, 4

            n_pos = np.random.randint(1, pos_range + 1)
            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
                n_neg = np.random.randint(1, neg_range + 1)
            else:
                n_neg = np.random.randint(0, neg_range + 1)
            get_boxes = False
            multimask_output = False

        else:  # sample boxes AND random no. of points
            # here we can have (1, 0) because we also have box
            pos_range, neg_range = 4, 4

            n_pos = np.random.randint(1, pos_range + 1)
            n_neg = np.random.randint(0, neg_range + 1)
            get_boxes = True
            multimask_output = False

        return n_pos, n_neg, get_boxes, multimask_output

    def _compute_iou(self, pred, true, eps=1e-7):
        """Compute the IoU score between the prediction and target.

        The prediction is binarized at 0.5 and the IoU is reduced over
        the axes (1, 2, 3), i.e. one value per entry of the first axis is returned.
        `eps` avoids a division by zero for an empty union.
        """
        pred_mask = pred > 0.5  # binarizing the output predictions
        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
        iou = overlap / (union + eps)
        return iou

    def _compute_loss(self, batched_outputs, y_one_hot):
        """Compute the loss for one iteration. The loss is made up of two components:
        - The mask loss: dice score between the predicted masks and targets.
        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.

        Both components are summed (not averaged) over the samples in the batch.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss), with loss being the sum of the other two.
        """
        mask_loss, iou_regression_loss = 0.0, 0.0

        # Loop over the batch.
        for batch_output, targets in zip(batched_outputs, y_one_hot):

            predicted_objects = torch.sigmoid(batch_output["masks"])
            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
            # We swap the axes that go into the dice loss so that the object axis
            # corresponds to the channel axes. This ensures that the dice is computed
            # independently per channel. We do not reduce the channel axis in the dice,
            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
            dice_scores = torch.stack([
                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
                for i in range(predicted_objects.shape[1])
            ])
            dice_scores, _ = torch.min(dice_scores, dim=0)

            # Compute the actual IOU between the predicted and true objects.
            # The outer loop is for the 1 or 3 predicted masks per true object.
            with torch.no_grad():
                true_iou = torch.stack([
                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
                ])
            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
            # the object axis is back in the first dimension.
            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])

            mask_loss = mask_loss + torch.mean(dice_scores)
            iou_regression_loss = iou_regression_loss + iou_score

        loss = mask_loss + iou_regression_loss

        return loss, mask_loss, iou_regression_loss

    #
    # Functionality for iterative prompting loss
    #

    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
        """Select, for each object, the predicted mask with the highest predicted IOU.

        Returns:
            The binarized masks and the corresponding low-res logits,
            both with a singleton axis for the selected prediction:
            (batch_size, n_objects, 1, height, width).
        """
        # Batched mask and logit (low-res mask) predictions.
        masks = torch.stack([m["masks"] for m in batched_outputs])
        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])

        # Determine the best IOU across the multi-object prediction axis
        # and turn this into a mask we can use for indexing.
        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
        # for details on the indexing logic.
        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()

        # Index the mask and logits with the best iou indices.
        # Note that we squash the first two axes (batch x objects) into one when indexing.
        # That's why we need to reshape back into (batch x objects) using a view.
        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
        batch_size, n_objects = masks.shape[:2]
        h, w = masks.shape[-2:]
        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)

        h, w = logits.shape[-2:]
        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)

        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
        masks = (masks > 0.0).float()
        return masks, logits

    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
        """Compute the loss for several (sub-)iterations of iterative prompting.
        In each iteration the prompts are updated based on the previous predictions.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, mean_model_iou),
            each averaged over the sub-iterations.
        """
        # The image embeddings are computed once and re-used for all sub-iterations.
        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0

        for i in range(0, num_subiter):
            # We do multimasking only in the first sub-iteration as we then pass single prompt
            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
            batched_outputs = self.model(batched_inputs,
                                         image_embeddings=image_embeddings,
                                         multimask_output=multimask_output if i == 0 else False)

            # Compute loss for this sub-iteration.
            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)

            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
            with torch.no_grad():
                net_mean_model_iou = torch.mean(batched_iou_predictions)

            loss += net_loss
            mask_loss += net_mask_loss
            iou_regression_loss += net_iou_regression_loss
            mean_model_iou += net_mean_model_iou

            if i < (num_subiter - 1):  # We need not update the prompts for the last iteration.
                # Determine the next prompts based on current predictions.
                with torch.no_grad():
                    # Get the mask and logit predictions corresponding to the predicted object
                    # (per actual object) with the best IOU.
                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)

        loss = loss / num_subiter
        mask_loss = mask_loss / num_subiter
        iou_regression_loss = iou_regression_loss / num_subiter
        mean_model_iou = mean_model_iou / num_subiter

        return loss, mask_loss, iou_regression_loss, mean_model_iou

    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
        """Extend the prompts of each sample with new point prompts from the prompt generator.

        The dictionaries in `batched_inputs` are updated in place: the new point coordinates and
        labels are appended to existing point prompts (or set if there are none), and the
        low-res logits are set as 'mask_inputs' with probability `self.mask_prob`.

        Returns:
            The updated `batched_inputs`.
        """
        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
            # here, we get each object in the pairs and do the point choices per-object
            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)

            # convert the point coordinates to the expected resolution for iterative prompting
            # NOTE:
            #   - "only" need to transform the point prompts from the iterative prompting
            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])

            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
                if "point_coords" in _inp.keys() else net_coords
            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
                if "point_labels" in _inp.keys() else net_labels

            _inp["point_coords"] = updated_point_coords
            _inp["point_labels"] = updated_point_labels

            if self.mask_prob > 0:
                # using mask inputs for iterative prompting while training, with a probability
                use_mask_inputs = (random.random() < self.mask_prob)
                if use_mask_inputs:
                    _inp["mask_inputs"] = logits
                else:  # remove previously existing mask inputs to avoid using them in next sub-iteration
                    _inp.pop("mask_inputs", None)

        return batched_inputs

    #
    # Training Loop
    #

    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
        """Compute one hot target (one mask per channel) for the sampled ids
        and restrict the number of sampled objects to the minimal number in the batch.

        Returns:
            The restricted `batched_inputs` and the one hot targets (as float tensor).
        """
        assert len(y) == len(sampled_ids)

        # Get the minimal number of objects in this batch.
        # The number of objects in a patch might be < n_objects_per_batch.
        # This is why we need to restrict it here to ensure the same
        # number of objects across the batch.
        n_objects = min(len(ids) for ids in sampled_ids)

        y = y.to(self.device, non_blocking=True)
        # Compute the one hot targets for the seg-id.
        y_one_hot = torch.stack([
            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
            for target, ids in zip(y, sampled_ids)
        ]).float()

        # Also restrict the prompts to the number of objects.
        batched_inputs = [
            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
            for inp in batched_inputs
        ]
        return batched_inputs, y_one_hot

    def _interactive_train_iteration(self, x, y):
        """Run a single training iteration with iterative prompting.

        The prompt setting is chosen based on the current trainer iteration.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, model_iou, y_one_hot).
        """
        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)

        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
            batched_inputs, y_one_hot,
            num_subiter=self.n_sub_iteration, multimask_output=multimask_output
        )
        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot

    def _check_input_normalization(self, x, input_check_done):
        """Warn if the input data looks like it was normalized already.

        The check is only run if `input_check_done` is False.

        Returns:
            True, to signal that the check has been performed.
        """
        # The expected data range of the SAM model is 8bit (0-255).
        # It can easily happen that data is normalized beforehand in training.
        # For some reasons we don't fully understand this still works, but it
        # should still be avoided and is very detrimental in some settings
        # (e.g. when freezing the image encoder)
        # We check once per epoch if the data seems to be normalized already and
        # raise a warning if this is the case.
        if not input_check_done:
            data_min, data_max = x.min(), x.max()
            if (data_min < 0) or (data_max < 1):
                # The trailing spaces are required: adjacent string literals are
                # concatenated without any separator.
                warnings.warn(
                    "It looks like you are normalizing the training data. "
                    "The SAM model takes care of normalization, so it is better to not do this. "
                    "We recommend to remove data normalization and input data in the range [0, 255]."
                )
            input_check_done = True

        return input_check_done

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Train for one epoch (or until `max_iteration` is reached).

        Returns:
            The average time per iteration in seconds.
        """
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, y)

            backprop(loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                # The sampled targets are only passed on when images are logged in this iteration.
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(self._iteration, loss, lr, x, y, samples,
                                      mask_loss, iou_regression_loss, model_iou)

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _interactive_val_iteration(self, x, y, val_iteration):
        """Run a single validation iteration (without iterative prompting).

        The prompt setting is chosen based on `val_iteration`.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric).
        """
        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)

        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

        batched_outputs = self.model(
            batched_inputs,
            image_embeddings=image_embeddings,
            multimask_output=multimask_output,
        )

        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
        # We use the dice loss over the masks as metric.
        metric = mask_loss
        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))

        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric

    def _validate_impl(self, forward_context):
        """Run validation over the full validation loader.

        All logged scalars are averaged over the validation batches.
        The logged images correspond to the last validation batch.

        Returns:
            The average metric (dice loss over the masks) over the validation batches.
        """
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
        mask_loss_val, iou_regression_loss_val = 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)

                loss_val += loss.item()
                metric_val += metric.item()
                model_iou_val += model_iou.item()
                # Accumulate the loss components as well, so that the logged values are averages
                # over the validation set instead of the values of the (arbitrary) last batch.
                mask_loss_val += mask_loss.item()
                iou_regression_loss_val += iou_regression_loss.item()
                val_iteration += 1

        n_batches = len(self.val_loader)
        loss_val /= n_batches
        metric_val /= n_batches
        model_iou_val /= n_batches
        mask_loss_val /= n_batches
        iou_regression_loss_val /= n_batches
        print()
        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, y,
                sampled_binary_y, mask_loss_val, iou_regression_loss_val, model_iou_val
            )

        return metric_val


class SamLogger(TorchEmLogger):
    """@private"""
    def __init__(self, trainer, save_root, **unused_kwargs):
        # Import the tensorboard writer explicitly: `import torch` alone does not guarantee
        # that the `torch.utils.tensorboard` submodule has been loaded.
        from torch.utils.tensorboard import SummaryWriter

        super().__init__(trainer, save_root)
        # Logs are written to "./logs/<name>" if no save root is given,
        # otherwise to "<save_root>/logs/<name>".
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
            os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        # Log the input, the target and a grid of the sampled objects for the first sample in the batch.
        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        # Scalars are logged in every step, images only every `log_image_interval` steps.
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        # Scalars and images are logged in every call.
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
class
SamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
19class SamTrainer(torch_em.trainer.DefaultTrainer): 20 """Trainer class for training the Segment Anything model. 21 22 This class is derived from `torch_em.trainer.DefaultTrainer`. 23 Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py 24 for details on its usage and implementation. 25 26 Args: 27 convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. 28 The class `micro_sam.training.util.ConvertToSamInputs` can be used here. 29 n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. 30 In each sub-iteration new point prompts are sampled where the model was wrong. 31 n_objects_per_batch: If not given, we compute the loss for all objects in a sample. 32 Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled. 33 mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. 34 prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training 35 mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`) 36 mask_loss: The loss to compare the predicted masks and the targets. 37 kwargs: The keyword arguments of the DefaultTrainer super class. 38 """ 39 40 def __init__( 41 self, 42 convert_inputs, 43 n_sub_iteration: int, 44 n_objects_per_batch: Optional[int] = None, 45 mse_loss: torch.nn.Module = torch.nn.MSELoss(), 46 prompt_generator: PromptGeneratorBase = IterativePromptGenerator(), 47 mask_prob: float = 0.5, 48 mask_loss: Optional[torch.nn.Module] = None, 49 **kwargs 50 ): 51 if mask_loss is None: 52 # We have to use the Dice Loss with reduce channel set to None. 53 # Hence we hard-code it here to avoid issues by passsing wrong options for the loss. 
54 self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None) 55 else: 56 self.mask_loss = mask_loss 57 58 super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs) 59 self.convert_inputs = convert_inputs 60 self.mse_loss = mse_loss 61 self.n_objects_per_batch = n_objects_per_batch 62 self.n_sub_iteration = n_sub_iteration 63 self.prompt_generator = prompt_generator 64 self.mask_prob = mask_prob 65 self._kwargs = kwargs 66 67 def _get_prompt_and_multimasking_choices(self, current_iteration): 68 """Choose the type of prompts we sample for training, and then we call 69 'convert_inputs' with the correct prompting from here. 70 """ 71 if current_iteration % 2 == 0: # sample only a single point per object 72 n_pos, n_neg = 1, 0 73 get_boxes = False 74 multimask_output = True 75 76 else: # sample only a single box per object 77 n_pos, n_neg = 0, 0 78 get_boxes = True 79 multimask_output = False 80 81 return n_pos, n_neg, get_boxes, multimask_output 82 83 def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): 84 """Choose the type of prompts we sample for validation, and then we call 85 'convert_inputs' with the correct prompting from here. 86 """ 87 if current_iteration % 4 == 0: # sample only a single point per object 88 n_pos, n_neg = 1, 0 89 get_boxes = False 90 multimask_output = True 91 92 elif current_iteration % 4 == 1: # sample only a single box per object 93 n_pos, n_neg = 0, 0 94 get_boxes = True 95 multimask_output = False 96 97 elif current_iteration % 4 == 2: # sample a random no. of points 98 pos_range, neg_range = 4, 4 99 100 n_pos = np.random.randint(1, pos_range + 1) 101 if n_pos == 1: # to avoid (1, 0) combination for redundancy but still have (n_pos, 0) 102 n_neg = np.random.randint(1, neg_range + 1) 103 else: 104 n_neg = np.random.randint(0, neg_range + 1) 105 get_boxes = False 106 multimask_output = False 107 108 else: # sample boxes AND random no. 
of points 109 # here we can have (1, 0) because we also have box 110 pos_range, neg_range = 4, 4 111 112 n_pos = np.random.randint(1, pos_range + 1) 113 n_neg = np.random.randint(0, neg_range + 1) 114 get_boxes = True 115 multimask_output = False 116 117 return n_pos, n_neg, get_boxes, multimask_output 118 119 def _compute_iou(self, pred, true, eps=1e-7): 120 """Compute the IoU score between the prediction and target. 121 """ 122 pred_mask = pred > 0.5 # binarizing the output predictions 123 overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3)) 124 union = pred_mask.logical_or(true).sum(dim=(1, 2, 3)) 125 iou = overlap / (union + eps) 126 return iou 127 128 def _compute_loss(self, batched_outputs, y_one_hot): 129 """Compute the loss for one iteration. The loss is made up of two components: 130 - The mask loss: dice score between the predicted masks and targets. 131 - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target. 132 """ 133 mask_loss, iou_regression_loss = 0.0, 0.0 134 135 # Loop over the batch. 136 for batch_output, targets in zip(batched_outputs, y_one_hot): 137 138 predicted_objects = torch.sigmoid(batch_output["masks"]) 139 # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop). 140 # We swap the axes that go into the dice loss so that the object axis 141 # corresponds to the channel axes. This ensures that the dice is computed 142 # independetly per channel. We do not reduce the channel axis in the dice, 143 # so that we can take the minimum (best score) of the 1/3 predicted masks per object. 144 dice_scores = torch.stack([ 145 self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1)) 146 for i in range(predicted_objects.shape[1]) 147 ]) 148 dice_scores, _ = torch.min(dice_scores, dim=0) 149 150 # Compute the actual IOU between the predicted and true objects. 151 # The outer loop is for the 1 or 3 predicted masks per true object. 
152 with torch.no_grad(): 153 true_iou = torch.stack([ 154 self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1]) 155 ]) 156 # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that 157 # the object axis is back in the first dimension. 158 iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"]) 159 160 mask_loss = mask_loss + torch.mean(dice_scores) 161 iou_regression_loss = iou_regression_loss + iou_score 162 163 loss = mask_loss + iou_regression_loss 164 165 return loss, mask_loss, iou_regression_loss 166 167 # 168 # Functionality for iterative prompting loss 169 # 170 171 def _get_best_masks(self, batched_outputs, batched_iou_predictions): 172 # Batched mask and logit (low-res mask) predictions. 173 masks = torch.stack([m["masks"] for m in batched_outputs]) 174 logits = torch.stack([m["low_res_masks"] for m in batched_outputs]) 175 176 # Determine the best IOU across the multi-object prediction axis 177 # and turn this into a mask we can use for indexing. 178 # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax 179 # for details on the indexing logic. 180 best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True) 181 best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool() 182 183 # Index the mask and logits with the best iou indices. 184 # Note that we squash the first two axes (batch x objects) into one when indexing. 185 # That's why we need to reshape bax into (batch x objects) using a view. 186 # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...) 187 batch_size, n_objects = masks.shape[:2] 188 h, w = masks.shape[-2:] 189 masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w) 190 191 h, w = logits.shape[-2:] 192 logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w) 193 194 # Binarize the mask. 
Note that the mask here also contains logits, so we use 0.0 195 # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid) 196 masks = (masks > 0.0).float() 197 return masks, logits 198 199 def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output): 200 """Compute the loss for several (sub-)iterations of iterative prompting. 201 In each iterations the prompts are updated based on the previous predictions. 202 """ 203 image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) 204 205 loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0 206 207 for i in range(0, num_subiter): 208 # We do multimasking only in the first sub-iteration as we then pass single prompt 209 # after the first sub-iteration, we don't do multimasking because we get multiple prompts. 210 batched_outputs = self.model(batched_inputs, 211 image_embeddings=image_embeddings, 212 multimask_output=multimask_output if i == 0 else False) 213 214 # Compute loss for tis sub-iteration. 215 net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot) 216 217 # Compute the mean IOU predicted by the model. We keep track of this in the logger. 218 batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs]) 219 with torch.no_grad(): 220 net_mean_model_iou = torch.mean(batched_iou_predictions) 221 222 loss += net_loss 223 mask_loss += net_mask_loss 224 iou_regression_loss += net_iou_regression_loss 225 mean_model_iou += net_mean_model_iou 226 227 if i < (num_subiter - 1): # We need not update the prompts for the last iteration. 228 # Determine the next prompts based on current predictions. 229 with torch.no_grad(): 230 # Get the mask and logit predictions corresponding to the predicted object 231 # (per actual object) with the best IOU. 
232 masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions) 233 batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits) 234 235 loss = loss / num_subiter 236 mask_loss = mask_loss / num_subiter 237 iou_regression_loss = iou_regression_loss / num_subiter 238 mean_model_iou = mean_model_iou / num_subiter 239 240 return loss, mask_loss, iou_regression_loss, mean_model_iou 241 242 def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks): 243 # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs") 244 for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks): 245 # here, we get each object in the pairs and do the point choices per-object 246 net_coords, net_labels, _, _ = self.prompt_generator(x2, x1) 247 248 # convert the point coordinates to the expected resolution for iterative prompting 249 # NOTE: 250 # - "only" need to transform the point prompts from the iterative prompting 251 # - the `logits` are the low res masks (256, 256), hence do not need the transform 252 net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:]) 253 254 updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \ 255 if "point_coords" in _inp.keys() else net_coords 256 updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \ 257 if "point_labels" in _inp.keys() else net_labels 258 259 _inp["point_coords"] = updated_point_coords 260 _inp["point_labels"] = updated_point_labels 261 262 if self.mask_prob > 0: 263 # using mask inputs for iterative prompting while training, with a probability 264 use_mask_inputs = (random.random() < self.mask_prob) 265 if use_mask_inputs: 266 _inp["mask_inputs"] = logits 267 else: # remove previously existing mask inputs to avoid using them in next sub-iteration 268 _inp.pop("mask_inputs", None) 269 270 return batched_inputs 271 272 # 273 # Training Loop 274 # 275 276 
def _preprocess_batch(self, batched_inputs, y, sampled_ids):
    """Compute one-hot targets (one mask per channel) for the sampled ids
    and restrict the number of sampled objects to the minimal number in the batch.

    Args:
        batched_inputs: One dictionary of model inputs per sample in the batch.
        y: The label batch. It is moved to `self.device` before the targets are computed.
        sampled_ids: One sequence of sampled segmentation ids per sample in the batch.

    Returns:
        The batched inputs with the prompt entries ("point_coords", "point_labels", "boxes")
        truncated to the common number of objects, and the one-hot targets as float tensor.
    """
    assert len(y) == len(sampled_ids)

    # Get the minimal number of objects in this batch.
    # The number of objects in a patch might be < n_objects_per_batch.
    # This is why we need to restrict it here to ensure the same
    # number of objects across the batch (otherwise the targets could not be stacked).
    n_objects = min(len(ids) for ids in sampled_ids)

    y = y.to(self.device, non_blocking=True)
    # Compute the one hot targets for the seg-id: one binary mask per sampled id.
    y_one_hot = torch.stack([
        torch.stack([target == seg_id for seg_id in ids[:n_objects]])
        for target, ids in zip(y, sampled_ids)
    ]).float()

    # Also restrict the prompts to the number of objects.
    # All other entries of the input dictionaries are passed through unchanged.
    batched_inputs = [
        {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
        for inp in batched_inputs
    ]
    return batched_inputs, y_one_hot


def _interactive_train_iteration(self, x, y):
    """Run a single training iteration with iterative prompting.

    The prompt settings are derived from the current training iteration via
    `_get_prompt_and_multimasking_choices`, the dataloader outputs are converted
    with `convert_inputs`, and the loss is computed over `n_sub_iteration` sub-iterations.

    Args:
        x: The input batch from the training loader.
        y: The label batch from the training loader.

    Returns:
        The loss, the mask loss, the IoU regression loss, the mean IoU predicted by the model
        and the one-hot targets of the sampled objects.
    """
    n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)

    batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
    batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

    loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
        batched_inputs, y_one_hot,
        num_subiter=self.n_sub_iteration, multimask_output=multimask_output
    )
    return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot


def _check_input_normalization(self, x, input_check_done):
    """Warn if the input data looks like it was normalized before being passed to the model.

    Args:
        x: The input batch. Only its minimum and maximum are inspected.
        input_check_done: Whether the check was already performed. If True, nothing is checked.

    Returns:
        Always True, so that the caller can store it as the updated `input_check_done` flag.
    """
    # The expected data range of the SAM model is 8bit (0-255).
    # It can easily happen that data is normalized beforehand in training.
    # For some reasons we don't fully understand this still works, but it
    # should still be avoided and is very detrimental in some settings
    # (e.g. when freezing the image encoder)
    # We check once per epoch if the data seems to be normalized already and
    # raise a warning if this is the case.
    if not input_check_done:
        data_min, data_max = x.min(), x.max()
        # NOTE(review): data scaled to [0, 1] whose maximum is exactly 1 is not caught
        # by this condition - confirm whether `data_max <= 1` is intended.
        if (data_min < 0) or (data_max < 1):
            # The sentences are joined by implicit string concatenation,
            # so each one needs its own trailing space.
            warnings.warn(
                "It looks like you are normalizing the training data. "
                "The SAM model takes care of normalization, so it is better to not do this. "
                "We recommend to remove data normalization and input data in the range [0, 255]."
            )
        input_check_done = True

    return input_check_done


def _train_epoch_impl(self, progress, forward_context, backprop):
    """Train the model for one epoch, or until `max_iteration` is reached.

    Args:
        progress: The progress bar; `update(1)` is called after each iteration that does not end the training.
        forward_context: The context manager factory that wraps the forward pass.
        backprop: The callable that is invoked with the loss to run the backward pass and the optimizer step.

    Returns:
        The average time per iteration in seconds.
    """
    self.model.train()

    # The normalization check is only done for the first batch of the epoch.
    input_check_done = False

    n_iter = 0
    t_per_iter = time.time()
    for x, y in self.train_loader:
        input_check_done = self._check_input_normalization(x, input_check_done)

        self.optimizer.zero_grad()

        with forward_context():
            (loss, mask_loss, iou_regression_loss, model_iou,
             sampled_binary_y) = self._interactive_train_iteration(x, y)

        backprop(loss)

        if self.logger is not None:
            lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
            # The sampled targets are only passed to the logger in the image logging interval.
            samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
            self.logger.log_train(self._iteration, loss, lr, x, y, samples,
                                  mask_loss, iou_regression_loss, model_iou)

        self._iteration += 1
        n_iter += 1
        if self._iteration >= self.max_iteration:
            break
        progress.update(1)

    t_per_iter = (time.time() - t_per_iter) / n_iter
    return t_per_iter


def _interactive_val_iteration(self, x, y, val_iteration):
    """Run a single validation iteration with one prompting step (no iterative prompting).

    Args:
        x: The input batch from the validation loader.
        y: The label batch from the validation loader.
        val_iteration: The index of the batch within the current validation run.
            It is used to select the prompt settings.

    Returns:
        The loss, the mask loss, the IoU regression loss, the mean IoU predicted by the model,
        the one-hot targets of the sampled objects and the metric (identical to the mask loss).
    """
    n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)

    batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
    batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

    image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

    batched_outputs = self.model(
        batched_inputs,
        image_embeddings=image_embeddings,
        multimask_output=multimask_output,
    )

    loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
    # We use the dice loss over the masks as metric.
    metric = mask_loss
    model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))

    return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric


def _validate_impl(self, forward_context):
    """Run the validation over the full validation loader.

    Args:
        forward_context: The context manager factory that wraps the forward pass.

    Returns:
        The metric averaged over the batches of the validation loader.
    """
    self.model.eval()

    # The normalization check is only done for the first batch of the validation run.
    input_check_done = False

    val_iteration = 0
    metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0

    with torch.no_grad():
        for x, y in self.val_loader:
            input_check_done = self._check_input_normalization(x, input_check_done)

            with forward_context():
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)

            loss_val += loss.item()
            metric_val += metric.item()
            model_iou_val += model_iou.item()
            val_iteration += 1

    loss_val /= len(self.val_loader)
    metric_val /= len(self.val_loader)
    model_iou_val /= len(self.val_loader)
    print()
    print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")

    if self.logger is not None:
        # The inputs, labels, sampled targets, mask loss and IoU regression loss passed here
        # are the values from the last validation batch; only the metric, the loss and
        # the model IoU are averaged over all batches.
        self.logger.log_validation(
            self._iteration, metric_val, loss_val, x, y,
            sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
        )

    return metric_val
Trainer class for training the Segment Anything model.
This class is derived from `torch_em.trainer.DefaultTrainer`.
Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
for details on its usage and implementation.
Arguments:
- convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
- n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. In each sub-iteration new point prompts are sampled where the model was wrong.
- n_objects_per_batch: If not given, we compute the loss for all objects in a sample. Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
- mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
- prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
- mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
- mask_loss: The loss to compare the predicted masks and the targets.
- kwargs: The keyword arguments of the DefaultTrainer super class.
SamTrainer( convert_inputs, n_sub_iteration: int, n_objects_per_batch: Optional[int] = None, mse_loss: torch.nn.modules.module.Module = MSELoss(), prompt_generator: micro_sam.prompt_generators.PromptGeneratorBase = <micro_sam.prompt_generators.IterativePromptGenerator object>, mask_prob: float = 0.5, mask_loss: Optional[torch.nn.modules.module.Module] = None, **kwargs)
def __init__(
    self,
    convert_inputs,
    n_sub_iteration: int,
    n_objects_per_batch: Optional[int] = None,
    mse_loss: torch.nn.Module = torch.nn.MSELoss(),
    prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
    mask_prob: float = 0.5,
    mask_loss: Optional[torch.nn.Module] = None,
    **kwargs
):
    """Set up the trainer: select the mask loss, initialize the super class and store the SAM-specific settings.

    The mask loss is passed to the super class both as `loss` and as `metric`.
    All additional keyword arguments are forwarded to the super class and stored in `self._kwargs`.
    """
    # If no mask loss is given we fall back to the Dice Loss with reduce channel set to None.
    # It is hard-coded here to avoid issues by passing wrong options for the loss.
    self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None) if mask_loss is None else mask_loss

    super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)

    # Conversion of the dataloader outputs and the prompting logic.
    self.convert_inputs = convert_inputs
    self.prompt_generator = prompt_generator

    # Loss for the IoU regression.
    self.mse_loss = mse_loss

    # Settings for the object sampling and the iterative prompting.
    self.n_objects_per_batch = n_objects_per_batch
    self.n_sub_iteration = n_sub_iteration
    self.mask_prob = mask_prob

    self._kwargs = kwargs
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- from_checkpoint
- Serializer
- save_checkpoint
- load_checkpoint
- fit