micro_sam.training.sam_trainer
import os
import time
import random
import warnings
from typing import Optional, Callable

import numpy as np

import torch
from torchvision.utils import make_grid

import torch_em
from torch_em.trainer.logger_base import TorchEmLogger

from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator


class SamTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer class for training the Segment Anything model.

    This class is derived from `torch_em.trainer.DefaultTrainer`.
    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
    for details on its usage and implementation.

    Args:
        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
            In each sub-iteration new point prompts are sampled where the model was wrong.
        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
            By default, set to the expected mse loss function.
        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
            Already allocated with the desired prompt generator by default.
        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
            By default, set to 0.5.
        mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function.
        kwargs: The keyword arguments of the `DefaultTrainer` super class.
    """

    def __init__(
        self,
        convert_inputs: Callable,
        n_sub_iteration: int,
        n_objects_per_batch: Optional[int] = None,
        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
        mask_prob: float = 0.5,
        mask_loss: Optional[torch.nn.Module] = None,
        **kwargs
    ):
        if mask_loss is None:
            # We have to use the Dice Loss with reduce channel set to None.
            # Hence we hard-code it here to avoid issues by passing wrong options for the loss.
            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
        else:
            self.mask_loss = mask_loss

        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
        self.convert_inputs = convert_inputs
        self.mse_loss = mse_loss
        self.n_objects_per_batch = n_objects_per_batch
        self.n_sub_iteration = n_sub_iteration
        self.prompt_generator = prompt_generator
        self.mask_prob = mask_prob
        self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized()
        self._kwargs = kwargs

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Choose the type of prompts we sample for training, and then we call
        'convert_inputs' with the correct prompting from here.
        """
        if current_iteration % 2 == 0:  # sample only a single point per object
            n_pos, n_neg = 1, 0
            get_boxes = False
            multimask_output = True

        else:  # sample only a single box per object
            n_pos, n_neg = 0, 0
            get_boxes = True
            multimask_output = False

        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Choose the type of prompts we sample for validation, and then we call
        'convert_inputs' with the correct prompting from here.
        """
        if current_iteration % 4 == 0:  # sample only a single point per object
            n_pos, n_neg = 1, 0
            get_boxes = False
            multimask_output = True

        elif current_iteration % 4 == 1:  # sample only a single box per object
            n_pos, n_neg = 0, 0
            get_boxes = True
            multimask_output = False

        elif current_iteration % 4 == 2:  # sample a random no. of points
            pos_range, neg_range = 4, 4

            n_pos = np.random.randint(1, pos_range + 1)
            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
                n_neg = np.random.randint(1, neg_range + 1)
            else:
                n_neg = np.random.randint(0, neg_range + 1)
            get_boxes = False
            multimask_output = False

        else:  # sample boxes AND random no. of points
            # here we can have (1, 0) because we also have box
            pos_range, neg_range = 4, 4

            n_pos = np.random.randint(1, pos_range + 1)
            n_neg = np.random.randint(0, neg_range + 1)
            get_boxes = True
            multimask_output = False

        return n_pos, n_neg, get_boxes, multimask_output

    def _compute_iou(self, pred, true, eps=1e-7):
        """Compute the IoU score between the prediction and target.
        """
        pred_mask = pred > 0.5  # binarizing the output predictions
        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
        iou = overlap / (union + eps)
        return iou

    def _compute_loss(self, batched_outputs, y_one_hot):
        """Compute the loss for one iteration. The loss is made up of two components:
        - The mask loss: dice score between the predicted masks and targets.
        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
        """
        mask_loss, iou_regression_loss = 0.0, 0.0

        # Loop over the batch.
        for batch_output, targets in zip(batched_outputs, y_one_hot):

            predicted_objects = torch.sigmoid(batch_output["masks"])
            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
            # We swap the axes that go into the dice loss so that the object axis
            # corresponds to the channel axes. This ensures that the dice is computed
            # independently per channel. We do not reduce the channel axis in the dice,
            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
            dice_scores = torch.stack([
                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
                for i in range(predicted_objects.shape[1])
            ])
            dice_scores, _ = torch.min(dice_scores, dim=0)

            # Compute the actual IOU between the predicted and true objects.
            # The outer loop is for the 1 or 3 predicted masks per true object.
            with torch.no_grad():
                true_iou = torch.stack([
                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
                ])
            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
            # the object axis is back in the first dimension.
            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])

            mask_loss = mask_loss + torch.mean(dice_scores)
            iou_regression_loss = iou_regression_loss + iou_score

        loss = mask_loss + iou_regression_loss

        return loss, mask_loss, iou_regression_loss

    #
    # Functionality for iterative prompting loss
    #

    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
        # Batched mask and logit (low-res mask) predictions.
        masks = torch.stack([m["masks"] for m in batched_outputs])
        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])

        # Determine the best IOU across the multi-object prediction axis
        # and turn this into a mask we can use for indexing.
        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
        # for details on the indexing logic.
        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()

        # Index the mask and logits with the best iou indices.
        # Note that we squash the first two axes (batch x objects) into one when indexing.
        # That's why we need to reshape back into (batch x objects) using a view.
        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
        batch_size, n_objects = masks.shape[:2]
        h, w = masks.shape[-2:]
        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)

        h, w = logits.shape[-2:]
        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)

        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
        masks = (masks > 0.0).float()
        return masks, logits

    def _use_mask_inputs(self, batched_inputs, y_one_hot):
        # Whether to use masks per training top-iteration.
        use_mask_inputs = False  # determines if each sub-iteration will use mask inputs as prompts or not.
        use_zero_mask = False  # determines if the zeroth iteration will use zeros as mask inputs.

        if self.mask_prob == 1:  # i.e. always use masks.
            use_mask_inputs = True  # we would like to use mask inputs in all sub-iterations.
            use_zero_mask = self.is_data_parallel  # we would like to use zeros as mask inputs for zeroth iteration.

        elif self.mask_prob > 0:  # i.e. if we use mask inputs with a probability.
            if self.is_data_parallel:  # if training on multiple GPUs.
                if torch.distributed.get_rank() == 0:  # device with rank 0.
                    use_mask_inputs_tensor = torch.tensor(
                        random.random() < self.mask_prob, dtype=torch.uint8, device=self.device,
                    )
                else:  # on other devices, we do not need this parameter at this stage.
                    use_mask_inputs_tensor = torch.tensor(0, dtype=torch.uint8, device=self.device)

                # Broadcast the value to all devices (ranks).
                torch.distributed.broadcast(use_mask_inputs_tensor, src=0)

                # And convert it back to our desired boolean value.
                use_mask_inputs = bool(use_mask_inputs_tensor.item())
                use_zero_mask = use_mask_inputs  # provides zeros as mask inputs.
            else:  # training on a single GPU.
                use_mask_inputs = None

        if use_zero_mask:
            # We use zeros as mask inputs for the zeroth iteration.
            y_zeros = torch.zeros((*y_one_hot.shape[:3], 256, 256))

            # Add zeros as mask inputs to batched inputs.
            for bi, curr_masks in zip(batched_inputs, y_zeros):
                bi["mask_inputs"] = curr_masks

        return batched_inputs, use_mask_inputs

    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
        """Compute the loss for several (sub-)iterations of iterative prompting.
        In each iteration the prompts are updated based on the previous predictions.
        """
        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0

        # Whether to use mask inputs in each sub-iteration.
        batched_inputs, use_mask_inputs = self._use_mask_inputs(batched_inputs, y_one_hot)

        for i in range(0, num_subiter):
            # We do multimasking only in the first sub-iteration, as we then pass a single prompt.
            # After the first sub-iteration we don't do multimasking, because we get multiple prompts.
            batched_outputs = self.model(
                batched_inputs=batched_inputs,
                image_embeddings=image_embeddings,
                multimask_output=multimask_output if i == 0 else False,
            )

            # Compute loss for this sub-iteration.
            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)

            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
            with torch.no_grad():
                net_mean_model_iou = torch.mean(batched_iou_predictions)

            loss = loss + net_loss
            mask_loss = mask_loss + net_mask_loss
            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
            mean_model_iou = mean_model_iou + net_mean_model_iou

            if i < (num_subiter - 1):  # We need not update the prompts for the last iteration.
                # Determine the next prompts based on current predictions.
                with torch.no_grad():
                    # Get the mask and logit predictions corresponding to the predicted object
                    # (per actual object) with the best IOU.
                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits, use_mask_inputs)

        loss = loss / num_subiter
        mask_loss = mask_loss / num_subiter
        iou_regression_loss = iou_regression_loss / num_subiter
        mean_model_iou = mean_model_iou / num_subiter

        return loss, mask_loss, iou_regression_loss, mean_model_iou

    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks, use_mask_inputs):
        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
            # here, we get each object in the pairs and do the point choices per-object
            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)

            # convert the point coordinates to the expected resolution for iterative prompting
            # NOTE:
            #   - we "only" need to transform the point prompts from the iterative prompting
            #   - the `logits` are the low res masks (256, 256), hence they do not need the transform
            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])

            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
                if "point_coords" in _inp.keys() else net_coords
            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
                if "point_labels" in _inp.keys() else net_labels

            _inp["point_coords"] = updated_point_coords
            _inp["point_labels"] = updated_point_labels

            if self.is_data_parallel:  # multi-GPU training
                use_mask_inputs_this_iter = use_mask_inputs
            else:  # single GPU training
                if self.mask_prob > 0:
                    # using mask inputs for iterative prompting while training, with a probability
                    use_mask_inputs = (random.random() < self.mask_prob)
                else:  # otherwise we assume it is 0 and do not need the generator to decide.
                    use_mask_inputs = False

                use_mask_inputs_this_iter = use_mask_inputs

            if use_mask_inputs_this_iter:
                _inp["mask_inputs"] = logits
            else:  # remove previously existing mask inputs to avoid using them in the next sub-iteration.
                _inp.pop("mask_inputs", None)

        return batched_inputs

    #
    # Training Loop
    #

    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
        """Compute one hot target (one mask per channel) for the sampled ids
        and restrict the number of sampled objects to the minimal number in the batch.
        """
        assert len(y) == len(sampled_ids)

        # Get the minimal number of objects in this batch.
        # The number of objects in a patch might be < n_objects_per_batch.
        # This is why we need to restrict it here to ensure the same
        # number of objects across the batch.
        n_objects = min(len(ids) for ids in sampled_ids)

        y = y.to(self.device, non_blocking=True)
        # Compute the one hot targets for the seg-id.
        y_one_hot = torch.stack([
            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
            for target, ids in zip(y, sampled_ids)
        ]).float()

        # Also restrict the prompts to the number of objects.
        batched_inputs = [
            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
            for inp in batched_inputs
        ]
        return batched_inputs, y_one_hot

    def _interactive_train_iteration(self, x, y):
        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)

        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
            batched_inputs=batched_inputs,
            y_one_hot=y_one_hot,
            num_subiter=self.n_sub_iteration,
            multimask_output=multimask_output
        )
        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot

    def _check_input_normalization(self, x, input_check_done):
        # The expected data range of the SAM model is 8bit (0-255).
        # It can easily happen that data is normalized beforehand in training.
        # For some reason we don't fully understand, this still works, but it
        # should still be avoided and is very detrimental in some settings
        # (e.g. when freezing the image encoder).
        # We check once per epoch if the data seems to be normalized already and
        # raise a warning if this is the case.
        if not input_check_done:
            data_min, data_max = x.min(), x.max()
            if (data_min < 0) or (data_max < 1):
                warnings.warn(
                    "It looks like you are normalizing the training data. "
                    "The SAM model takes care of normalization, so it is better to not do this. "
                    "We recommend to remove data normalization and input data in the range [0, 255]."
                )
            input_check_done = True

        return input_check_done

    def _train_epoch_impl(self, progress, forward_context, backprop):
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, y)

            backprop(loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(
                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        t_per_iter = (time.time() - t_per_iter) / n_iter
        return t_per_iter

    def _interactive_val_iteration(self, x, y, val_iteration):
        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)

        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

        batched_outputs = self.model(
            batched_inputs=batched_inputs,
            image_embeddings=image_embeddings,
            multimask_output=multimask_output,
        )

        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
        # We use the dice loss over the masks as metric.
        metric = mask_loss
        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))

        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric

    def _validate_impl(self, forward_context):
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)

                loss_val += loss.item()
                metric_val += metric.item()
                model_iou_val += model_iou.item()
                val_iteration += 1

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        model_iou_val /= len(self.val_loader)
        print()
        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, y,
                sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
            )

        return metric_val


class SamLogger(TorchEmLogger):
    """@private"""
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
class SamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
Trainer class for training the Segment Anything model.
This class is derived from torch_em.trainer.DefaultTrainer.
Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
for details on its usage and implementation.
Arguments:
- convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSamInputs can be used here.
- n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. In each sub-iteration new point prompts are sampled where the model was wrong.
- n_objects_per_batch: If not given, we compute the loss for all objects in a sample. Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
- mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. By default, set to the expected mse loss function.
- prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training. Already allocated with the desired prompt generator by default.
- mask_prob: The probability of using the mask inputs in the iterative prompting (per n_sub_iteration). By default, set to 0.5.
- mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function.
- kwargs: The keyword arguments of the DefaultTrainer superclass (see the construction sketch after this list for how these fit together).
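
The following is a minimal construction sketch that ties these arguments together. It assumes that train_loader and val_loader are torch_em dataloaders yielding raw images in the range [0, 255] together with instance segmentation labels, and it uses get_trainable_sam_model and ConvertToSamInputs from micro_sam.training.util. Treat the exact keyword set as an assumption; the keyword arguments accepted by DefaultTrainer may differ between torch_em versions.

    import torch

    from micro_sam.training.sam_trainer import SamTrainer, SamLogger
    from micro_sam.training.util import ConvertToSamInputs, get_trainable_sam_model

    # Assumed to exist already: torch_em dataloaders yielding raw images in the
    # range [0, 255] together with instance segmentation labels.
    train_loader, val_loader = ..., ...  # hypothetical placeholders

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # SAM wrapped such that it can be trained with prompt-based inputs.
    model = get_trainable_sam_model(model_type="vit_b", device=device)

    # Converts (image, labels) batches into the prompt-based input format of SAM.
    convert_inputs = ConvertToSamInputs(transform=model.transform)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    trainer = SamTrainer(
        name="sam-finetuning",       # used for the checkpoint and log folders
        train_loader=train_loader,
        val_loader=val_loader,
        model=model,
        optimizer=optimizer,
        device=device,
        logger=SamLogger,
        log_image_interval=100,
        convert_inputs=convert_inputs,
        n_sub_iteration=8,           # number of iterative prompting sub-iterations
        n_objects_per_batch=25,      # cap on the number of sampled objects per image
        mask_prob=0.5,
    )

Note that loss and metric are set internally (to the dice loss), so they should not be passed as keyword arguments.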
SamTrainer(
    convert_inputs: Callable,
    n_sub_iteration: int,
    n_objects_per_batch: Optional[int] = None,
    mse_loss: torch.nn.modules.module.Module = MSELoss(),
    prompt_generator: micro_sam.prompt_generators.PromptGeneratorBase = <micro_sam.prompt_generators.IterativePromptGenerator object>,
    mask_prob: float = 0.5,
    mask_loss: Optional[torch.nn.modules.module.Module] = None,
    **kwargs
)
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit
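
Since fit is inherited from DefaultTrainer, training is started the same way as for any other torch_em trainer. Continuing the construction sketch above (the exact fit signature may vary with the installed torch_em version):

    # Run training for a fixed number of iterations; checkpoints are written to
    # ./checkpoints/<name> (or below save_root, if given).
    trainer.fit(iterations=int(1e4))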