micro_sam.training.sam_trainer
import os
import time
import random
import warnings
from typing import Optional, Callable

import numpy as np

import torch
from torchvision.utils import make_grid

import torch_em
from torch_em.trainer.logger_base import TorchEmLogger

from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator


class SamTrainer(torch_em.trainer.DefaultTrainer):
    """Trainer class for training the Segment Anything model.

    This class is derived from `torch_em.trainer.DefaultTrainer`.
    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
    for details on its usage and implementation.

    Args:
        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
            In each sub-iteration new point prompts are sampled where the model was wrong.
        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
            By default, set to the expected mse loss function.
        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
            Already allocated with the desired prompt generator by default.
        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
            By default, set to '0.5'.
        mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function.
        kwargs: The keyword arguments of the `DefaultTrainer` super class.
    """

    # NOTE: the default `mse_loss` and `prompt_generator` objects are created once, when the
    # function is defined, and are therefore shared by all trainers that rely on the defaults.
    def __init__(
        self,
        convert_inputs: Callable,
        n_sub_iteration: int,
        n_objects_per_batch: Optional[int] = None,
        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
        mask_prob: float = 0.5,
        mask_loss: Optional[torch.nn.Module] = None,
        **kwargs
    ):
        if mask_loss is None:
            # We have to use the Dice Loss with reduce channel set to None.
            # Hence we hard-code it here to avoid issues by passing wrong options for the loss.
            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
        else:
            self.mask_loss = mask_loss

        # The mask loss is handed to the super class both as loss and as metric.
        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
        self.convert_inputs = convert_inputs
        self.mse_loss = mse_loss
        self.n_objects_per_batch = n_objects_per_batch
        self.n_sub_iteration = n_sub_iteration
        self.prompt_generator = prompt_generator
        self.mask_prob = mask_prob
        # True only if torch.distributed is available AND a process group has been initialized.
        self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized()
        self._kwargs = kwargs

    def _get_prompt_and_multimasking_choices(self, current_iteration):
        """Choose the type of prompts we sample for training, and then we call
        'convert_inputs' with the correct prompting from here.

        Even iterations use a single positive point with multi-masking,
        odd iterations use a single box without multi-masking.

        Returns:
            The tuple (n_pos, n_neg, get_boxes, multimask_output).
        """
        if current_iteration % 2 == 0:  # sample only a single point per object
            n_pos, n_neg = 1, 0
            get_boxes = False
            multimask_output = True

        else:  # sample only a single box per object
            n_pos, n_neg = 0, 0
            get_boxes = True
            multimask_output = False

        return n_pos, n_neg, get_boxes, multimask_output

    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
        """Choose the type of prompts we sample for validation, and then we call
        'convert_inputs' with the correct prompting from here.

        The prompt setting cycles with period four over the validation iterations:
        single point, single box, random number of points, box plus random number of points.

        Returns:
            The tuple (n_pos, n_neg, get_boxes, multimask_output).
        """
        if current_iteration % 4 == 0:  # sample only a single point per object
            n_pos, n_neg = 1, 0
            get_boxes = False
            multimask_output = True

        elif current_iteration % 4 == 1:  # sample only a single box per object
            n_pos, n_neg = 0, 0
            get_boxes = True
            multimask_output = False

        elif current_iteration % 4 == 2:  # sample a random no. of points
            pos_range, neg_range = 4, 4

            n_pos = np.random.randint(1, pos_range + 1)
            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
                n_neg = np.random.randint(1, neg_range + 1)
            else:
                n_neg = np.random.randint(0, neg_range + 1)
            get_boxes = False
            multimask_output = False

        else:  # sample boxes AND random no. of points
            # here we can have (1, 0) because we also have box
            pos_range, neg_range = 4, 4

            n_pos = np.random.randint(1, pos_range + 1)
            n_neg = np.random.randint(0, neg_range + 1)
            get_boxes = True
            multimask_output = False

        return n_pos, n_neg, get_boxes, multimask_output

    def _compute_iou(self, pred, true, eps=1e-7):
        """Compute the IoU score between the prediction and target.

        The prediction is binarized at 0.5 and the sums run over axes 1, 2 and 3,
        so both inputs are expected to be 4d; one IoU value is returned per entry of axis 0.
        `eps` avoids a division by zero for an empty union.
        """
        pred_mask = pred > 0.5  # binarizing the output predictions
        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
        iou = overlap / (union + eps)
        return iou

    def _compute_loss(self, batched_outputs, y_one_hot):
        """Compute the loss for one iteration. The loss is made up of two components:
        - The mask loss: dice score between the predicted masks and targets.
        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.

        Args:
            batched_outputs: One dictionary per sample, with the entries "masks" and "iou_predictions".
            y_one_hot: The one-hot encoded targets, iterated in lockstep with `batched_outputs`.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss), where loss is the sum of the other two.
        """
        mask_loss, iou_regression_loss = 0.0, 0.0
        batch_size = len(batched_outputs)

        # Loop over the batch.
        for batch_output, targets in zip(batched_outputs, y_one_hot):

            predicted_objects = torch.sigmoid(batch_output["masks"])
            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
            # We swap the axes that go into the dice loss so that the object axis
            # corresponds to the channel axes. This ensures that the dice is computed
            # independently per channel. We do not reduce the channel axis in the dice,
            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
            dice_scores = torch.stack([
                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
                for i in range(predicted_objects.shape[1])
            ])
            dice_scores, _ = torch.min(dice_scores, dim=0)

            # Compute the actual IOU between the predicted and true objects.
            # The outer loop is for the 1 or 3 predicted masks per true object.
            with torch.no_grad():
                true_iou = torch.stack([
                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
                ])
            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
            # the object axis is back in the first dimension.
            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])

            mask_loss = mask_loss + torch.mean(dice_scores)
            iou_regression_loss = iou_regression_loss + iou_score

        # Normalize by batch size so that loss/metric are comparable across batch sizes.
        mask_loss = mask_loss / batch_size
        iou_regression_loss = iou_regression_loss / batch_size
        loss = mask_loss + iou_regression_loss

        return loss, mask_loss, iou_regression_loss

    #
    # Functionality for iterative prompting loss
    #

    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
        """Select, per object, the predicted mask with the highest predicted IOU.

        Returns:
            The binarized masks and the corresponding low-res logits,
            both with a singleton axis in place of the multi-mask axis.
        """
        # Batched mask and logit (low-res mask) predictions.
        masks = torch.stack([m["masks"] for m in batched_outputs])
        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])

        # Determine the best IOU across the multi-object prediction axis
        # and turn this into a mask we can use for indexing.
        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
        # for details on the indexing logic.
        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()

        # Index the mask and logits with the best iou indices.
        # Note that we squash the first two axes (batch x objects) into one when indexing.
        # That's why we need to reshape back into (batch x objects) using a view.
        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
        batch_size, n_objects = masks.shape[:2]
        h, w = masks.shape[-2:]
        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)

        h, w = logits.shape[-2:]
        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)

        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
        masks = (masks > 0.0).float()
        return masks, logits

    def _use_mask_inputs(self, batched_inputs, y_one_hot):
        """Decide whether mask inputs are used for the current training iteration.

        Returns:
            The (possibly updated) batched inputs and `use_mask_inputs`, which is True / False if the
            decision is fixed for all sub-iterations, or None if it is left to `_update_prompts`
            (single device training with 0 < mask_prob < 1).
        """
        # Whether to use masks per training top-iteration.
        use_mask_inputs = False  # determines if each sub-iteration will use mask inputs as prompts or not.
        use_zero_mask = False  # determines if the zeroth iteration will use zeros as mask inputs.

        if self.mask_prob == 1:  # i.e. always use masks.
            use_mask_inputs = True  # we would like to use mask inputs in all sub-iterations.
            use_zero_mask = self.is_data_parallel  # we would like to use zeros as mask inputs for zeroth iteration.

        elif self.mask_prob > 0:  # i.e. if we use mask inputs with a probability.
            if self.is_data_parallel:  # if training on multiple GPUs.
                if torch.distributed.get_rank() == 0:  # device with rank 0.
                    use_mask_inputs_tensor = torch.tensor(
                        random.random() < self.mask_prob, dtype=torch.uint8, device=self.device,
                    )
                else:  # on other devices, we do not need this parameter at this stage.
                    use_mask_inputs_tensor = torch.tensor(0, dtype=torch.uint8, device=self.device)

                # Broadcast the value to all devices (ranks).
                torch.distributed.broadcast(use_mask_inputs_tensor, src=0)

                # And convert it back to our desired boolean value.
                use_mask_inputs = bool(use_mask_inputs_tensor.item())
                use_zero_mask = use_mask_inputs  # provides zeros as mask inputs.
            else:  # training on a single GPU.
                use_mask_inputs = None

        if use_zero_mask:
            # We use zeros as mask inputs for the zeroth iteration.
            # They are allocated on the device of the targets, so that they live on the same
            # device as the logits that replace them in the later sub-iterations.
            y_zeros = torch.zeros((*y_one_hot.shape[:3], 256, 256), device=y_one_hot.device)

            # Add zeros as mask inputs to batched inputs.
            for bi, curr_masks in zip(batched_inputs, y_zeros):
                bi["mask_inputs"] = curr_masks

        return batched_inputs, use_mask_inputs

    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
        """Compute the loss for several (sub-)iterations of iterative prompting.
        In each iterations the prompts are updated based on the previous predictions.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, mean_model_iou),
            each averaged over the `num_subiter` sub-iterations.
        """
        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0

        # Whether to use mask inputs in each sub-iteration.
        batched_inputs, use_mask_inputs = self._use_mask_inputs(batched_inputs, y_one_hot)

        for i in range(0, num_subiter):
            # We do multimasking only in the first sub-iteration as we then pass single prompt
            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
            batched_outputs = self.model(
                batched_inputs=batched_inputs,
                image_embeddings=image_embeddings,
                multimask_output=multimask_output if i == 0 else False,
            )

            # Compute loss for this sub-iteration.
            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)

            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
            with torch.no_grad():
                net_mean_model_iou = torch.mean(batched_iou_predictions)

            loss = loss + net_loss
            mask_loss = mask_loss + net_mask_loss
            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
            mean_model_iou = mean_model_iou + net_mean_model_iou

            if i < (num_subiter - 1):  # We need not update the prompts for the last iteration.
                # Determine the next prompts based on current predictions.
                with torch.no_grad():
                    # Get the mask and logit predictions corresponding to the predicted object
                    # (per actual object) with the best IOU.
                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits, use_mask_inputs)

        loss = loss / num_subiter
        mask_loss = mask_loss / num_subiter
        iou_regression_loss = iou_regression_loss / num_subiter
        mean_model_iou = mean_model_iou / num_subiter

        return loss, mask_loss, iou_regression_loss, mean_model_iou

    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks, use_mask_inputs):
        """Extend the prompts of each sample with new points derived from the current predictions.

        The entries of `batched_inputs` are updated in place: new point prompts are appended
        and the "mask_inputs" are either set to the current logits or removed.

        Returns:
            The updated batched inputs.
        """
        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
            # here, we get each object in the pairs and do the point choices per-object
            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)

            # convert the point coordinates to the expected resolution for iterative prompting
            # NOTE:
            #   - "only" need to transform the point prompts from the iterative prompting
            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])

            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
                if "point_coords" in _inp.keys() else net_coords
            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
                if "point_labels" in _inp.keys() else net_labels

            _inp["point_coords"] = updated_point_coords
            _inp["point_labels"] = updated_point_labels

            if self.is_data_parallel:  # multi-GPU training
                # The decision was already made (and broadcast to all ranks) in `_use_mask_inputs`.
                use_mask_inputs_this_iter = use_mask_inputs
            else:  # single GPU training
                if self.mask_prob > 0:
                    # using mask inputs for iterative prompting while training, with a probability
                    use_mask_inputs = (random.random() < self.mask_prob)
                else:  # otherwise we assume it is 0 and do not need the generator to decide.
                    use_mask_inputs = False

                use_mask_inputs_this_iter = use_mask_inputs

            if use_mask_inputs_this_iter:
                _inp["mask_inputs"] = logits
            else:  # remove previously existing mask inputs to avoid using them in next sub-iteration.
                _inp.pop("mask_inputs", None)

        return batched_inputs

    #
    # Training Loop
    #

    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
        """Compute one hot target (one mask per channel) for the sampled ids
        and restrict the number of sampled objects to the minimal number in the batch.

        Returns:
            The batched inputs with prompts restricted to the common number of objects
            and the one-hot targets as a float tensor on the trainer's device.
        """
        assert len(y) == len(sampled_ids)

        # Get the minimal number of objects in this batch.
        # The number of objects in a patch might be < n_objects_per_batch.
        # This is why we need to restrict it here to ensure the same
        # number of objects across the batch.
        n_objects = min(len(ids) for ids in sampled_ids)

        y = y.to(self.device, non_blocking=True)
        # Compute the one hot targets for the seg-id.
        y_one_hot = torch.stack([
            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
            for target, ids in zip(y, sampled_ids)
        ]).float()

        # Also restrict the prompts to the number of objects.
        batched_inputs = [
            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
            for inp in batched_inputs
        ]
        return batched_inputs, y_one_hot

    def _interactive_train_iteration(self, x, y):
        """Run one training iteration: sample prompts, build the targets and compute the iterative loss.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, model_iou, y_one_hot).
        """
        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)

        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
            batched_inputs=batched_inputs,
            y_one_hot=y_one_hot,
            num_subiter=self.n_sub_iteration,
            multimask_output=multimask_output
        )
        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot

    def _check_input_normalization(self, x, input_check_done):
        """Warn if the input data looks normalized. The check is skipped if `input_check_done` is True.

        Returns:
            True, unless the check was skipped with `input_check_done` set to False-equivalent input,
            i.e. the returned flag marks the check as done for subsequent calls.
        """
        # The expected data range of the SAM model is 8bit (0-255).
        # It can easily happen that data is normalized beforehand in training.
        # For some reasons we don't fully understand this still works, but it
        # should still be avoided and is very detrimental in some settings
        # (e.g. when freezing the image encoder)
        # We check once per epoch if the data seems to be normalized already and
        # raise a warning if this is the case.
        if not input_check_done:
            data_min, data_max = x.min(), x.max()
            # Use '<= 1' for the maximum, so that data normalized to exactly [0, 1] is detected as well.
            if (data_min < 0) or (data_max <= 1):
                warnings.warn(
                    "It looks like you are normalizing the training data. "
                    "The SAM model takes care of normalization, so it is better to not do this. "
                    "We recommend to remove data normalization and input data in the range [0, 255]."
                )
            input_check_done = True

        return input_check_done

    def _train_epoch_impl(self, progress, forward_context, backprop):
        """Train for one epoch, or until `max_iteration` is reached.

        Returns:
            The average time per training iteration in seconds.
        """
        self.model.train()

        input_check_done = False

        n_iter = 0
        t_per_iter = time.time()
        for x, y in self.train_loader:
            input_check_done = self._check_input_normalization(x, input_check_done)

            self.optimizer.zero_grad()

            with forward_context():
                (loss, mask_loss, iou_regression_loss, model_iou,
                 sampled_binary_y) = self._interactive_train_iteration(x, y)

            backprop(loss)

            if self.logger is not None:
                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                # The sampled targets are only passed on in the iterations where images are logged.
                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
                self.logger.log_train(
                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
                )

            self._iteration += 1
            n_iter += 1
            if self._iteration >= self.max_iteration:
                break
            progress.update(1)

        # Guard against a division by zero in case the loader did not yield any batch.
        t_per_iter = (time.time() - t_per_iter) / max(n_iter, 1)
        return t_per_iter

    def _interactive_val_iteration(self, x, y, val_iteration):
        """Run one validation iteration with the prompt setting selected by `val_iteration`.

        In contrast to training, only a single forward pass without iterative prompting is done.

        Returns:
            The tuple (loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric).
        """
        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)

        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)

        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)

        batched_outputs = self.model(
            batched_inputs=batched_inputs,
            image_embeddings=image_embeddings,
            multimask_output=multimask_output,
        )

        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
        # We use the dice loss over the masks as metric.
        metric = mask_loss
        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))

        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric

    def _validate_impl(self, forward_context):
        """Run the validation over the full validation loader and log the averaged results.

        NOTE: this assumes that the validation loader yields at least one batch; the averages
        divide by `len(self.val_loader)` and the last batch is used for logging images.

        Returns:
            The metric (dice loss over the masks) averaged over the validation loader.
        """
        self.model.eval()

        input_check_done = False

        val_iteration = 0
        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
        mask_loss_val, iou_loss_val = 0.0, 0.0

        with torch.no_grad():
            for x, y in self.val_loader:
                input_check_done = self._check_input_normalization(x, input_check_done)

                with forward_context():
                    (loss, mask_loss, iou_regression_loss, model_iou,
                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)

                loss_val += loss.item()
                metric_val += metric.item()
                mask_loss_val += mask_loss.item()
                iou_loss_val += iou_regression_loss.item()
                model_iou_val += model_iou.item()
                val_iteration += 1

        loss_val /= len(self.val_loader)
        metric_val /= len(self.val_loader)
        mask_loss_val /= len(self.val_loader)
        iou_loss_val /= len(self.val_loader)
        model_iou_val /= len(self.val_loader)
        print()
        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")

        if self.logger is not None:
            self.logger.log_validation(
                self._iteration, metric_val, loss_val, x, y,
                sampled_binary_y, mask_loss_val, iou_loss_val, model_iou_val
            )

        return metric_val


class SamLogger(TorchEmLogger):
    """@private"""
    def __init__(self, trainer, save_root, **unused_kwargs):
        super().__init__(trainer, save_root)
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        # Import the summary writer explicitly: `import torch` on its own does not load the
        # `torch.utils.tensorboard` submodule, so attribute access would depend on other imports.
        from torch.utils.tensorboard import SummaryWriter

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        # Log the first input and target of the batch, and a grid with the first entry of each sample.
        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        # Images are only logged every `log_image_interval` steps.
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
class
SamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
19class SamTrainer(torch_em.trainer.DefaultTrainer): 20 """Trainer class for training the Segment Anything model. 21 22 This class is derived from `torch_em.trainer.DefaultTrainer`. 23 Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py 24 for details on its usage and implementation. 25 26 Args: 27 convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. 28 The class `micro_sam.training.util.ConvertToSamInputs` can be used here. 29 n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. 30 In each sub-iteration new point prompts are sampled where the model was wrong. 31 n_objects_per_batch: If not given, we compute the loss for all objects in a sample. 32 Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled. 33 mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. 34 By default, set to the expected mse loss function. 35 prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training. 36 Already allocated with the desired prompt generator by default. 37 mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`). 38 By default, set to '0.5'. 39 mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function. 40 kwargs: The keyword arguments of the `DefaultTrainer` super class. 
41 """ 42 43 def __init__( 44 self, 45 convert_inputs: Callable, 46 n_sub_iteration: int, 47 n_objects_per_batch: Optional[int] = None, 48 mse_loss: torch.nn.Module = torch.nn.MSELoss(), 49 prompt_generator: PromptGeneratorBase = IterativePromptGenerator(), 50 mask_prob: float = 0.5, 51 mask_loss: Optional[torch.nn.Module] = None, 52 **kwargs 53 ): 54 if mask_loss is None: 55 # We have to use the Dice Loss with reduce channel set to None. 56 # Hence we hard-code it here to avoid issues by passsing wrong options for the loss. 57 self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None) 58 else: 59 self.mask_loss = mask_loss 60 61 super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs) 62 self.convert_inputs = convert_inputs 63 self.mse_loss = mse_loss 64 self.n_objects_per_batch = n_objects_per_batch 65 self.n_sub_iteration = n_sub_iteration 66 self.prompt_generator = prompt_generator 67 self.mask_prob = mask_prob 68 self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized() 69 self._kwargs = kwargs 70 71 def _get_prompt_and_multimasking_choices(self, current_iteration): 72 """Choose the type of prompts we sample for training, and then we call 73 'convert_inputs' with the correct prompting from here. 74 """ 75 if current_iteration % 2 == 0: # sample only a single point per object 76 n_pos, n_neg = 1, 0 77 get_boxes = False 78 multimask_output = True 79 80 else: # sample only a single box per object 81 n_pos, n_neg = 0, 0 82 get_boxes = True 83 multimask_output = False 84 85 return n_pos, n_neg, get_boxes, multimask_output 86 87 def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): 88 """Choose the type of prompts we sample for validation, and then we call 89 'convert_inputs' with the correct prompting from here. 
90 """ 91 if current_iteration % 4 == 0: # sample only a single point per object 92 n_pos, n_neg = 1, 0 93 get_boxes = False 94 multimask_output = True 95 96 elif current_iteration % 4 == 1: # sample only a single box per object 97 n_pos, n_neg = 0, 0 98 get_boxes = True 99 multimask_output = False 100 101 elif current_iteration % 4 == 2: # sample a random no. of points 102 pos_range, neg_range = 4, 4 103 104 n_pos = np.random.randint(1, pos_range + 1) 105 if n_pos == 1: # to avoid (1, 0) combination for redundancy but still have (n_pos, 0) 106 n_neg = np.random.randint(1, neg_range + 1) 107 else: 108 n_neg = np.random.randint(0, neg_range + 1) 109 get_boxes = False 110 multimask_output = False 111 112 else: # sample boxes AND random no. of points 113 # here we can have (1, 0) because we also have box 114 pos_range, neg_range = 4, 4 115 116 n_pos = np.random.randint(1, pos_range + 1) 117 n_neg = np.random.randint(0, neg_range + 1) 118 get_boxes = True 119 multimask_output = False 120 121 return n_pos, n_neg, get_boxes, multimask_output 122 123 def _compute_iou(self, pred, true, eps=1e-7): 124 """Compute the IoU score between the prediction and target. 125 """ 126 pred_mask = pred > 0.5 # binarizing the output predictions 127 overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3)) 128 union = pred_mask.logical_or(true).sum(dim=(1, 2, 3)) 129 iou = overlap / (union + eps) 130 return iou 131 132 def _compute_loss(self, batched_outputs, y_one_hot): 133 """Compute the loss for one iteration. The loss is made up of two components: 134 - The mask loss: dice score between the predicted masks and targets. 135 - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target. 136 """ 137 mask_loss, iou_regression_loss = 0.0, 0.0 138 batch_size = len(batched_outputs) 139 140 # Loop over the batch. 
141 for batch_output, targets in zip(batched_outputs, y_one_hot): 142 143 predicted_objects = torch.sigmoid(batch_output["masks"]) 144 # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop). 145 # We swap the axes that go into the dice loss so that the object axis 146 # corresponds to the channel axes. This ensures that the dice is computed 147 # independetly per channel. We do not reduce the channel axis in the dice, 148 # so that we can take the minimum (best score) of the 1/3 predicted masks per object. 149 dice_scores = torch.stack([ 150 self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1)) 151 for i in range(predicted_objects.shape[1]) 152 ]) 153 dice_scores, _ = torch.min(dice_scores, dim=0) 154 155 # Compute the actual IOU between the predicted and true objects. 156 # The outer loop is for the 1 or 3 predicted masks per true object. 157 with torch.no_grad(): 158 true_iou = torch.stack([ 159 self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1]) 160 ]) 161 # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that 162 # the object axis is back in the first dimension. 163 iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"]) 164 165 mask_loss = mask_loss + torch.mean(dice_scores) 166 iou_regression_loss = iou_regression_loss + iou_score 167 168 # Normalize by batch size so that loss/metric are comparable across batch sizes. 169 mask_loss = mask_loss / batch_size 170 iou_regression_loss = iou_regression_loss / batch_size 171 loss = mask_loss + iou_regression_loss 172 173 return loss, mask_loss, iou_regression_loss 174 175 # 176 # Functionality for iterative prompting loss 177 # 178 179 def _get_best_masks(self, batched_outputs, batched_iou_predictions): 180 # Batched mask and logit (low-res mask) predictions. 
181 masks = torch.stack([m["masks"] for m in batched_outputs]) 182 logits = torch.stack([m["low_res_masks"] for m in batched_outputs]) 183 184 # Determine the best IOU across the multi-object prediction axis 185 # and turn this into a mask we can use for indexing. 186 # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax 187 # for details on the indexing logic. 188 best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True) 189 best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool() 190 191 # Index the mask and logits with the best iou indices. 192 # Note that we squash the first two axes (batch x objects) into one when indexing. 193 # That's why we need to reshape bax into (batch x objects) using a view. 194 # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...) 195 batch_size, n_objects = masks.shape[:2] 196 h, w = masks.shape[-2:] 197 masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w) 198 199 h, w = logits.shape[-2:] 200 logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w) 201 202 # Binarize the mask. Note that the mask here also contains logits, so we use 0.0 203 # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid) 204 masks = (masks > 0.0).float() 205 return masks, logits 206 207 def _use_mask_inputs(self, batched_inputs, y_one_hot): 208 # Whether to use masks per training top-iteration. 209 use_mask_inputs = False # determines if each sub-iteration will use mask inputs as prompts or not. 210 use_zero_mask = False # determines if the zeroth iteration will use zeros as mask inputs. 211 212 if self.mask_prob == 1: # i.e. always use masks. 213 use_mask_inputs = True # we would like to use mask inputs in all sub-iterations. 214 use_zero_mask = self.is_data_parallel # we would like to use zeros as mask inputs for zeroth iteration. 215 216 elif self.mask_prob > 0: # i.e. 
if we use mask inputs with a probability. 217 if self.is_data_parallel: # if training on multiple GPUs. 218 if torch.distributed.get_rank() == 0: # device with rank 0. 219 use_mask_inputs_tensor = torch.tensor( 220 random.random() < self.mask_prob, dtype=torch.uint8, device=self.device, 221 ) 222 else: # on other devices, we do not need this parameter at this stage. 223 use_mask_inputs_tensor = torch.tensor(0, dtype=torch.uint8, device=self.device) 224 225 # Broadcast the value to all devices (ranks). 226 torch.distributed.broadcast(use_mask_inputs_tensor, src=0) 227 228 # And convert it back to our desired boolean value. 229 use_mask_inputs = bool(use_mask_inputs_tensor.item()) 230 use_zero_mask = use_mask_inputs # provides zeros as mask inputs. 231 else: # training on a single GPU. 232 use_mask_inputs = None 233 234 if use_zero_mask: 235 # We use zeros as mask inputs for the zeroth iteration. 236 y_zeros = torch.zeros((*y_one_hot.shape[:3], 256, 256)) 237 238 # Add zeros as mask inputs to batched inputs. 239 for bi, curr_masks in zip(batched_inputs, y_zeros): 240 bi["mask_inputs"] = curr_masks 241 242 return batched_inputs, use_mask_inputs 243 244 def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output): 245 """Compute the loss for several (sub-)iterations of iterative prompting. 246 In each iterations the prompts are updated based on the previous predictions. 247 """ 248 image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) 249 250 loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0 251 252 # Whether to use mask inputs in each sub-iteration. 253 batched_inputs, use_mask_inputs = self._use_mask_inputs(batched_inputs, y_one_hot) 254 255 for i in range(0, num_subiter): 256 # We do multimasking only in the first sub-iteration as we then pass single prompt 257 # after the first sub-iteration, we don't do multimasking because we get multiple prompts. 
258 batched_outputs = self.model( 259 batched_inputs=batched_inputs, 260 image_embeddings=image_embeddings, 261 multimask_output=multimask_output if i == 0 else False, 262 ) 263 264 # Compute loss for this sub-iteration. 265 net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot) 266 267 # Compute the mean IOU predicted by the model. We keep track of this in the logger. 268 batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs]) 269 with torch.no_grad(): 270 net_mean_model_iou = torch.mean(batched_iou_predictions) 271 272 loss = loss + net_loss 273 mask_loss = mask_loss + net_mask_loss 274 iou_regression_loss = iou_regression_loss + net_iou_regression_loss 275 mean_model_iou = mean_model_iou + net_mean_model_iou 276 277 if i < (num_subiter - 1): # We need not update the prompts for the last iteration. 278 # Determine the next prompts based on current predictions. 279 with torch.no_grad(): 280 # Get the mask and logit predictions corresponding to the predicted object 281 # (per actual object) with the best IOU. 
282 masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions) 283 batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits, use_mask_inputs) 284 285 loss = loss / num_subiter 286 mask_loss = mask_loss / num_subiter 287 iou_regression_loss = iou_regression_loss / num_subiter 288 mean_model_iou = mean_model_iou / num_subiter 289 290 return loss, mask_loss, iou_regression_loss, mean_model_iou 291 292 def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks, use_mask_inputs): 293 # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs") 294 for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks): 295 # here, we get each object in the pairs and do the point choices per-object 296 net_coords, net_labels, _, _ = self.prompt_generator(x2, x1) 297 298 # convert the point coordinates to the expected resolution for iterative prompting 299 # NOTE: 300 # - "only" need to transform the point prompts from the iterative prompting 301 # - the `logits` are the low res masks (256, 256), hence do not need the transform 302 net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:]) 303 304 updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \ 305 if "point_coords" in _inp.keys() else net_coords 306 updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \ 307 if "point_labels" in _inp.keys() else net_labels 308 309 _inp["point_coords"] = updated_point_coords 310 _inp["point_labels"] = updated_point_labels 311 312 if self.is_data_parallel: # multi-GPU training 313 use_mask_inputs_this_iter = use_mask_inputs 314 else: # single GPU training 315 if self.mask_prob > 0: 316 # using mask inputs for iterative prompting while training, with a probability 317 use_mask_inputs = (random.random() < self.mask_prob) 318 else: # otherwise we assume it is 0 and do not need the generator to decide. 
319 use_mask_inputs = False 320 321 use_mask_inputs_this_iter = use_mask_inputs 322 323 if use_mask_inputs_this_iter: 324 _inp["mask_inputs"] = logits 325 else: # remove previously existing mask inputs to avoid using them in next sub-iteration. 326 _inp.pop("mask_inputs", None) 327 328 return batched_inputs 329 330 # 331 # Training Loop 332 # 333 334 def _preprocess_batch(self, batched_inputs, y, sampled_ids): 335 """Compute one hot target (one mask per channel) for the sampled ids 336 and restrict the number of sampled objects to the minimal number in the batch. 337 """ 338 assert len(y) == len(sampled_ids) 339 340 # Get the minimal number of objects in this batch. 341 # The number of objects in a patch might be < n_objects_per_batch. 342 # This is why we need to restrict it here to ensure the same 343 # number of objects across the batch. 344 n_objects = min(len(ids) for ids in sampled_ids) 345 346 y = y.to(self.device, non_blocking=True) 347 # Compute the one hot targets for the seg-id. 348 y_one_hot = torch.stack([ 349 torch.stack([target == seg_id for seg_id in ids[:n_objects]]) 350 for target, ids in zip(y, sampled_ids) 351 ]).float() 352 353 # Also restrict the prompts to the number of objects. 
354 batched_inputs = [ 355 {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()} 356 for inp in batched_inputs 357 ] 358 return batched_inputs, y_one_hot 359 360 def _interactive_train_iteration(self, x, y): 361 n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration) 362 363 batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch) 364 batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids) 365 366 loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss( 367 batched_inputs=batched_inputs, 368 y_one_hot=y_one_hot, 369 num_subiter=self.n_sub_iteration, 370 multimask_output=multimask_output 371 ) 372 return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot 373 374 def _check_input_normalization(self, x, input_check_done): 375 # The expected data range of the SAM model is 8bit (0-255). 376 # It can easily happen that data is normalized beforehand in training. 377 # For some reasons we don't fully understand this still works, but it 378 # should still be avoided and is very detrimental in some settings 379 # (e.g. when freezing the image encoder) 380 # We check once per epoch if the data seems to be normalized already and 381 # raise a warning if this is the case. 382 if not input_check_done: 383 data_min, data_max = x.min(), x.max() 384 if (data_min < 0) or (data_max < 1): 385 warnings.warn( 386 "It looks like you are normalizing the training data. " 387 "The SAM model takes care of normalization, so it is better to not do this. " 388 "We recommend to remove data normalization and input data in the range [0, 255]." 
389 ) 390 input_check_done = True 391 392 return input_check_done 393 394 def _train_epoch_impl(self, progress, forward_context, backprop): 395 self.model.train() 396 397 input_check_done = False 398 399 n_iter = 0 400 t_per_iter = time.time() 401 for x, y in self.train_loader: 402 input_check_done = self._check_input_normalization(x, input_check_done) 403 404 self.optimizer.zero_grad() 405 406 with forward_context(): 407 (loss, mask_loss, iou_regression_loss, model_iou, 408 sampled_binary_y) = self._interactive_train_iteration(x, y) 409 410 backprop(loss) 411 412 if self.logger is not None: 413 lr = [pm["lr"] for pm in self.optimizer.param_groups][0] 414 samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None 415 self.logger.log_train( 416 self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou 417 ) 418 419 self._iteration += 1 420 n_iter += 1 421 if self._iteration >= self.max_iteration: 422 break 423 progress.update(1) 424 425 t_per_iter = (time.time() - t_per_iter) / n_iter 426 return t_per_iter 427 428 def _interactive_val_iteration(self, x, y, val_iteration): 429 n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration) 430 431 batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch) 432 batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids) 433 434 image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) 435 436 batched_outputs = self.model( 437 batched_inputs=batched_inputs, 438 image_embeddings=image_embeddings, 439 multimask_output=multimask_output, 440 ) 441 442 loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot) 443 # We use the dice loss over the masks as metric. 
444 metric = mask_loss 445 model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs])) 446 447 return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric 448 449 def _validate_impl(self, forward_context): 450 self.model.eval() 451 452 input_check_done = False 453 454 val_iteration = 0 455 metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0 456 mask_loss_val, iou_loss_val = 0.0, 0.0 457 458 with torch.no_grad(): 459 for x, y in self.val_loader: 460 input_check_done = self._check_input_normalization(x, input_check_done) 461 462 with forward_context(): 463 (loss, mask_loss, iou_regression_loss, model_iou, 464 sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration) 465 466 loss_val += loss.item() 467 metric_val += metric.item() 468 mask_loss_val += mask_loss.item() 469 iou_loss_val += iou_regression_loss.item() 470 model_iou_val += model_iou.item() 471 val_iteration += 1 472 473 loss_val /= len(self.val_loader) 474 metric_val /= len(self.val_loader) 475 mask_loss_val /= len(self.val_loader) 476 iou_loss_val /= len(self.val_loader) 477 model_iou_val /= len(self.val_loader) 478 print() 479 print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}") 480 481 if self.logger is not None: 482 self.logger.log_validation( 483 self._iteration, metric_val, loss_val, x, y, 484 sampled_binary_y, mask_loss_val, iou_loss_val, model_iou_val 485 ) 486 487 return metric_val
Trainer class for training the Segment Anything model.
This class is derived from torch_em.trainer.DefaultTrainer.
Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
for details on its usage and implementation.
Arguments:
- convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
- n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. In each sub-iteration new point prompts are sampled where the model was wrong.
- n_objects_per_batch: If not given, we compute the loss for all objects in a sample. Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
- mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. By default, set to the expected mse loss function.
- prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training. Already allocated with the desired prompt generator by default.
- mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`). By default, set to '0.5'.
- mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function.
- kwargs: The keyword arguments of the `DefaultTrainer` super class.
SamTrainer( convert_inputs: Callable, n_sub_iteration: int, n_objects_per_batch: Optional[int] = None, mse_loss: torch.nn.modules.module.Module = MSELoss(), prompt_generator: micro_sam.prompt_generators.PromptGeneratorBase = <micro_sam.prompt_generators.IterativePromptGenerator object>, mask_prob: float = 0.5, mask_loss: Optional[torch.nn.modules.module.Module] = None, **kwargs)
def __init__(
    self,
    convert_inputs: Callable,
    n_sub_iteration: int,
    n_objects_per_batch: Optional[int] = None,
    mse_loss: torch.nn.Module = torch.nn.MSELoss(),
    prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
    mask_prob: float = 0.5,
    mask_loss: Optional[torch.nn.Module] = None,
    **kwargs
):
    """Set up the trainer and store the settings for iterative prompting.

    The mask loss is passed to the super class both as loss and as metric.
    All further keyword arguments are forwarded to the super class.
    """
    # The Dice Loss must be used with reduce channel set to None. So we hard-code
    # this default here, to avoid issues due to wrong options being passed for the loss.
    self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None) if mask_loss is None else mask_loss

    super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)

    # Settings for the prompt sampling and the iterative prompting.
    self.convert_inputs = convert_inputs
    self.prompt_generator = prompt_generator
    self.n_sub_iteration = n_sub_iteration
    self.n_objects_per_batch = n_objects_per_batch
    self.mask_prob = mask_prob

    # The regression loss for the IOU predictions.
    self.mse_loss = mse_loss

    # We are in the distributed setting if a process group has been initialized.
    distributed = torch.distributed
    self.is_data_parallel = distributed.is_available() and distributed.is_initialized()

    self._kwargs = kwargs
Inherited Members
- torch_em.trainer.default_trainer.DefaultTrainer
- name
- id_
- train_loader
- val_loader
- model
- loss
- optimizer
- metric
- device
- lr_scheduler
- log_image_interval
- save_root
- compile_model
- rank
- mixed_precision
- early_stopping
- train_time
- logger_class
- logger_kwargs
- checkpoint_folder
- iteration
- epoch
- Deserializer
- Serializer
- fit