micro_sam.training.sam_trainer

  1import os
  2import time
  3import random
  4import warnings
  5from typing import Optional
  6
  7import numpy as np
  8
  9import torch
 10from torchvision.utils import make_grid
 11
 12import torch_em
 13from torch_em.trainer.logger_base import TorchEmLogger
 14
 15from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator
 16
 17
 18class SamTrainer(torch_em.trainer.DefaultTrainer):
 19    """Trainer class for training the Segment Anything model.
 20
 21    This class is derived from `torch_em.trainer.DefaultTrainer`.
 22    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 23    for details on its usage and implementation.
 24
 25    Args:
 26        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 27            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
 28        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
 29            In each sub-iteration new point prompts are sampled where the model was wrong.
 30        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
 31            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
 32        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
 33        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training
 34        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`)
 35        mask_loss: The loss to compare the predicted masks and the targets.
 36        kwargs: The keyword arguments of the DefaultTrainer super class.
 37    """
 38
 39    def __init__(
 40        self,
 41        convert_inputs,
 42        n_sub_iteration: int,
 43        n_objects_per_batch: Optional[int] = None,
 44        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
 45        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
 46        mask_prob: float = 0.5,
 47        mask_loss: Optional[torch.nn.Module] = None,
 48        **kwargs
 49    ):
 50        if mask_loss is None:
 51            # We have to use the Dice Loss with reduce channel set to None.
 52            # Hence we hard-code it here to avoid issues by passsing wrong options for the loss.
 53            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
 54        else:
 55            self.mask_loss = mask_loss
 56
 57        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
 58        self.convert_inputs = convert_inputs
 59        self.mse_loss = mse_loss
 60        self.n_objects_per_batch = n_objects_per_batch
 61        self.n_sub_iteration = n_sub_iteration
 62        self.prompt_generator = prompt_generator
 63        self.mask_prob = mask_prob
 64        self._kwargs = kwargs
 65
 66    def _get_prompt_and_multimasking_choices(self, current_iteration):
 67        """Choose the type of prompts we sample for training, and then we call
 68        'convert_inputs' with the correct prompting from here.
 69        """
 70        if current_iteration % 2 == 0:  # sample only a single point per object
 71            n_pos, n_neg = 1, 0
 72            get_boxes = False
 73            multimask_output = True
 74
 75        else:  # sample only a single box per object
 76            n_pos, n_neg = 0, 0
 77            get_boxes = True
 78            multimask_output = False
 79
 80        return n_pos, n_neg, get_boxes, multimask_output
 81
 82    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
 83        """Choose the type of prompts we sample for validation, and then we call
 84        'convert_inputs' with the correct prompting from here.
 85        """
 86        if current_iteration % 4 == 0:  # sample only a single point per object
 87            n_pos, n_neg = 1, 0
 88            get_boxes = False
 89            multimask_output = True
 90
 91        elif current_iteration % 4 == 1:  # sample only a single box per object
 92            n_pos, n_neg = 0, 0
 93            get_boxes = True
 94            multimask_output = False
 95
 96        elif current_iteration % 4 == 2:  # sample a random no. of points
 97            pos_range, neg_range = 4, 4
 98
 99            n_pos = np.random.randint(1, pos_range + 1)
100            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
101                n_neg = np.random.randint(1, neg_range + 1)
102            else:
103                n_neg = np.random.randint(0, neg_range + 1)
104            get_boxes = False
105            multimask_output = False
106
107        else:  # sample boxes AND random no. of points
108            # here we can have (1, 0) because we also have box
109            pos_range, neg_range = 4, 4
110
111            n_pos = np.random.randint(1, pos_range + 1)
112            n_neg = np.random.randint(0, neg_range + 1)
113            get_boxes = True
114            multimask_output = False
115
116        return n_pos, n_neg, get_boxes, multimask_output
117
118    def _compute_iou(self, pred, true, eps=1e-7):
119        """Compute the IoU score between the prediction and target.
120        """
121        pred_mask = pred > 0.5  # binarizing the output predictions
122        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
123        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
124        iou = overlap / (union + eps)
125        return iou
126
127    def _compute_loss(self, batched_outputs, y_one_hot):
128        """Compute the loss for one iteration. The loss is made up of two components:
129        - The mask loss: dice score between the predicted masks and targets.
130        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
131        """
132        mask_loss, iou_regression_loss = 0.0, 0.0
133
134        # Loop over the batch.
135        for batch_output, targets in zip(batched_outputs, y_one_hot):
136
137            predicted_objects = torch.sigmoid(batch_output["masks"])
138            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
139            # We swap the axes that go into the dice loss so that the object axis
140            # corresponds to the channel axes. This ensures that the dice is computed
141            # independetly per channel. We do not reduce the channel axis in the dice,
142            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
143            dice_scores = torch.stack([
144                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
145                for i in range(predicted_objects.shape[1])
146            ])
147            dice_scores, _ = torch.min(dice_scores, dim=0)
148
149            # Compute the actual IOU between the predicted and true objects.
150            # The outer loop is for the 1 or 3 predicted masks per true object.
151            with torch.no_grad():
152                true_iou = torch.stack([
153                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
154                ])
155            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
156            # the object axis is back in the first dimension.
157            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])
158
159            mask_loss = mask_loss + torch.mean(dice_scores)
160            iou_regression_loss = iou_regression_loss + iou_score
161
162        loss = mask_loss + iou_regression_loss
163
164        return loss, mask_loss, iou_regression_loss
165
166    #
167    # Functionality for iterative prompting loss
168    #
169
170    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
171        # Batched mask and logit (low-res mask) predictions.
172        masks = torch.stack([m["masks"] for m in batched_outputs])
173        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
174
175        # Determine the best IOU across the multi-object prediction axis
176        # and turn this into a mask we can use for indexing.
177        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
178        # for details on the indexing logic.
179        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
180        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()
181
182        # Index the mask and logits with the best iou indices.
183        # Note that we squash the first two axes (batch x objects) into one when indexing.
184        # That's why we need to reshape bax into (batch x objects) using a view.
185        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
186        batch_size, n_objects = masks.shape[:2]
187        h, w = masks.shape[-2:]
188        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
189
190        h, w = logits.shape[-2:]
191        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
192
193        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
194        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
195        masks = (masks > 0.0).float()
196        return masks, logits
197
198    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
199        """Compute the loss for several (sub-)iterations of iterative prompting.
200        In each iterations the prompts are updated based on the previous predictions.
201        """
202        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
203
204        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
205
206        for i in range(0, num_subiter):
207            # We do multimasking only in the first sub-iteration as we then pass single prompt
208            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
209            batched_outputs = self.model(batched_inputs,
210                                         image_embeddings=image_embeddings,
211                                         multimask_output=multimask_output if i == 0 else False)
212
213            # Compute loss for tis sub-iteration.
214            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
215
216            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
217            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
218            with torch.no_grad():
219                net_mean_model_iou = torch.mean(batched_iou_predictions)
220
221            loss += net_loss
222            mask_loss += net_mask_loss
223            iou_regression_loss += net_iou_regression_loss
224            mean_model_iou += net_mean_model_iou
225
226            if i < (num_subiter - 1):   # We need not update the prompts for the last iteration.
227                # Determine the next prompts based on current predictions.
228                with torch.no_grad():
229                    # Get the mask and logit predictions corresponding to the predicted object
230                    # (per actual object) with the best IOU.
231                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
232                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)
233
234        loss = loss / num_subiter
235        mask_loss = mask_loss / num_subiter
236        iou_regression_loss = iou_regression_loss / num_subiter
237        mean_model_iou = mean_model_iou / num_subiter
238
239        return loss, mask_loss, iou_regression_loss, mean_model_iou
240
241    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
242        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
243        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
244            # here, we get each object in the pairs and do the point choices per-object
245            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
246
247            # convert the point coordinates to the expected resolution for iterative prompting
248            # NOTE:
249            #   - "only" need to transform the point prompts from the iterative prompting
250            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
251            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
252
253            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
254                if "point_coords" in _inp.keys() else net_coords
255            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
256                if "point_labels" in _inp.keys() else net_labels
257
258            _inp["point_coords"] = updated_point_coords
259            _inp["point_labels"] = updated_point_labels
260
261            if self.mask_prob > 0:
262                # using mask inputs for iterative prompting while training, with a probability
263                use_mask_inputs = (random.random() < self.mask_prob)
264                if use_mask_inputs:
265                    _inp["mask_inputs"] = logits
266                else:  # remove  previously existing mask inputs to avoid using them in next sub-iteration
267                    _inp.pop("mask_inputs", None)
268
269        return batched_inputs
270
271    #
272    # Training Loop
273    #
274
275    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
276        """Compute one hot target (one mask per channel) for the sampled ids
277        and restrict the number of sampled objects to the minimal number in the batch.
278        """
279        assert len(y) == len(sampled_ids)
280
281        # Get the minimal number of objects in this batch.
282        # The number of objects in a patch might be < n_objects_per_batch.
283        # This is why we need to restrict it here to ensure the same
284        # number of objects across the batch.
285        n_objects = min(len(ids) for ids in sampled_ids)
286
287        y = y.to(self.device, non_blocking=True)
288        # Compute the one hot targets for the seg-id.
289        y_one_hot = torch.stack([
290            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
291            for target, ids in zip(y, sampled_ids)
292        ]).float()
293
294        # Also restrict the prompts to the number of objects.
295        batched_inputs = [
296            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
297            for inp in batched_inputs
298        ]
299        return batched_inputs, y_one_hot
300
301    def _interactive_train_iteration(self, x, y):
302        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
303
304        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
305        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
306
307        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
308            batched_inputs, y_one_hot,
309            num_subiter=self.n_sub_iteration, multimask_output=multimask_output
310        )
311        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
312
313    def _check_input_normalization(self, x, input_check_done):
314        # The expected data range of the SAM model is 8bit (0-255).
315        # It can easily happen that data is normalized beforehand in training.
316        # For some reasons we don't fully understand this still works, but it
317        # should still be avoided and is very detrimental in some settings
318        # (e.g. when freezing the image encoder)
319        # We check once per epoch if the data seems to be normalized already and
320        # raise a warning if this is the case.
321        if not input_check_done:
322            data_min, data_max = x.min(), x.max()
323            if (data_min < 0) or (data_max < 1):
324                warnings.warn(
325                    "It looks like you are normalizing the training data."
326                    "The SAM model takes care of normalization, so it is better to not do this."
327                    "We recommend to remove data normalization and input data in the range [0, 255]."
328                )
329            input_check_done = True
330
331        return input_check_done
332
333    def _train_epoch_impl(self, progress, forward_context, backprop):
334        self.model.train()
335
336        input_check_done = False
337
338        n_iter = 0
339        t_per_iter = time.time()
340        for x, y in self.train_loader:
341            input_check_done = self._check_input_normalization(x, input_check_done)
342
343            self.optimizer.zero_grad()
344
345            with forward_context():
346                (loss, mask_loss, iou_regression_loss, model_iou,
347                 sampled_binary_y) = self._interactive_train_iteration(x, y)
348
349            backprop(loss)
350
351            if self.logger is not None:
352                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
353                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
354                self.logger.log_train(self._iteration, loss, lr, x, y, samples,
355                                      mask_loss, iou_regression_loss, model_iou)
356
357            self._iteration += 1
358            n_iter += 1
359            if self._iteration >= self.max_iteration:
360                break
361            progress.update(1)
362
363        t_per_iter = (time.time() - t_per_iter) / n_iter
364        return t_per_iter
365
366    def _interactive_val_iteration(self, x, y, val_iteration):
367        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
368
369        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
370        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
371
372        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
373
374        batched_outputs = self.model(
375            batched_inputs,
376            image_embeddings=image_embeddings,
377            multimask_output=multimask_output,
378        )
379
380        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
381        # We use the dice loss over the masks as metric.
382        metric = mask_loss
383        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
384
385        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
386
387    def _validate_impl(self, forward_context):
388        self.model.eval()
389
390        input_check_done = False
391
392        val_iteration = 0
393        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
394
395        with torch.no_grad():
396            for x, y in self.val_loader:
397                input_check_done = self._check_input_normalization(x, input_check_done)
398
399                with forward_context():
400                    (loss, mask_loss, iou_regression_loss, model_iou,
401                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
402
403                loss_val += loss.item()
404                metric_val += metric.item()
405                model_iou_val += model_iou.item()
406                val_iteration += 1
407
408        loss_val /= len(self.val_loader)
409        metric_val /= len(self.val_loader)
410        model_iou_val /= len(self.val_loader)
411        print()
412        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")
413
414        if self.logger is not None:
415            self.logger.log_validation(
416                self._iteration, metric_val, loss_val, x, y,
417                sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
418            )
419
420        return metric_val
421
422
class SamLogger(TorchEmLogger):
    """@private"""
    # Tensorboard logger for the SamTrainer. In addition to the total loss it logs
    # the mask loss, the IOU regression loss and the mean IOU predicted by the model.

    def __init__(self, trainer, save_root, **unused_kwargs):
        # The logs are written to "<save_root>/logs/<trainer.name>",
        # or to "./logs/<trainer.name>" if no save_root is given.
        super().__init__(trainer, save_root)
        if save_root is None:
            self.log_dir = f"./logs/{trainer.name}"
        else:
            self.log_dir = os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        # Import the writer explicitly: `import torch` alone does not guarantee that
        # the `torch.utils.tensorboard` submodule is loaded and available as attribute.
        from torch.utils.tensorboard import SummaryWriter

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        # Log the first input and target of the batch, and a grid (4 images per row)
        # with the first entry of each element in samples.
        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        # Log the training scalars, and the images every `log_image_interval` steps.
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        # Log the validation scalars and, unlike in training, always the images.
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
class SamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
 19class SamTrainer(torch_em.trainer.DefaultTrainer):
 20    """Trainer class for training the Segment Anything model.
 21
 22    This class is derived from `torch_em.trainer.DefaultTrainer`.
 23    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 24    for details on its usage and implementation.
 25
 26    Args:
 27        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 28            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
 29        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
 30            In each sub-iteration new point prompts are sampled where the model was wrong.
 31        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
 32            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
 33        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
 34        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training
 35        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`)
 36        mask_loss: The loss to compare the predicted masks and the targets.
 37        kwargs: The keyword arguments of the DefaultTrainer super class.
 38    """
 39
 40    def __init__(
 41        self,
 42        convert_inputs,
 43        n_sub_iteration: int,
 44        n_objects_per_batch: Optional[int] = None,
 45        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
 46        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
 47        mask_prob: float = 0.5,
 48        mask_loss: Optional[torch.nn.Module] = None,
 49        **kwargs
 50    ):
 51        if mask_loss is None:
 52            # We have to use the Dice Loss with reduce channel set to None.
 53            # Hence we hard-code it here to avoid issues by passsing wrong options for the loss.
 54            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
 55        else:
 56            self.mask_loss = mask_loss
 57
 58        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
 59        self.convert_inputs = convert_inputs
 60        self.mse_loss = mse_loss
 61        self.n_objects_per_batch = n_objects_per_batch
 62        self.n_sub_iteration = n_sub_iteration
 63        self.prompt_generator = prompt_generator
 64        self.mask_prob = mask_prob
 65        self._kwargs = kwargs
 66
 67    def _get_prompt_and_multimasking_choices(self, current_iteration):
 68        """Choose the type of prompts we sample for training, and then we call
 69        'convert_inputs' with the correct prompting from here.
 70        """
 71        if current_iteration % 2 == 0:  # sample only a single point per object
 72            n_pos, n_neg = 1, 0
 73            get_boxes = False
 74            multimask_output = True
 75
 76        else:  # sample only a single box per object
 77            n_pos, n_neg = 0, 0
 78            get_boxes = True
 79            multimask_output = False
 80
 81        return n_pos, n_neg, get_boxes, multimask_output
 82
 83    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
 84        """Choose the type of prompts we sample for validation, and then we call
 85        'convert_inputs' with the correct prompting from here.
 86        """
 87        if current_iteration % 4 == 0:  # sample only a single point per object
 88            n_pos, n_neg = 1, 0
 89            get_boxes = False
 90            multimask_output = True
 91
 92        elif current_iteration % 4 == 1:  # sample only a single box per object
 93            n_pos, n_neg = 0, 0
 94            get_boxes = True
 95            multimask_output = False
 96
 97        elif current_iteration % 4 == 2:  # sample a random no. of points
 98            pos_range, neg_range = 4, 4
 99
100            n_pos = np.random.randint(1, pos_range + 1)
101            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
102                n_neg = np.random.randint(1, neg_range + 1)
103            else:
104                n_neg = np.random.randint(0, neg_range + 1)
105            get_boxes = False
106            multimask_output = False
107
108        else:  # sample boxes AND random no. of points
109            # here we can have (1, 0) because we also have box
110            pos_range, neg_range = 4, 4
111
112            n_pos = np.random.randint(1, pos_range + 1)
113            n_neg = np.random.randint(0, neg_range + 1)
114            get_boxes = True
115            multimask_output = False
116
117        return n_pos, n_neg, get_boxes, multimask_output
118
119    def _compute_iou(self, pred, true, eps=1e-7):
120        """Compute the IoU score between the prediction and target.
121        """
122        pred_mask = pred > 0.5  # binarizing the output predictions
123        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
124        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
125        iou = overlap / (union + eps)
126        return iou
127
128    def _compute_loss(self, batched_outputs, y_one_hot):
129        """Compute the loss for one iteration. The loss is made up of two components:
130        - The mask loss: dice score between the predicted masks and targets.
131        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
132        """
133        mask_loss, iou_regression_loss = 0.0, 0.0
134
135        # Loop over the batch.
136        for batch_output, targets in zip(batched_outputs, y_one_hot):
137
138            predicted_objects = torch.sigmoid(batch_output["masks"])
139            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
140            # We swap the axes that go into the dice loss so that the object axis
141            # corresponds to the channel axes. This ensures that the dice is computed
142            # independetly per channel. We do not reduce the channel axis in the dice,
143            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
144            dice_scores = torch.stack([
145                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
146                for i in range(predicted_objects.shape[1])
147            ])
148            dice_scores, _ = torch.min(dice_scores, dim=0)
149
150            # Compute the actual IOU between the predicted and true objects.
151            # The outer loop is for the 1 or 3 predicted masks per true object.
152            with torch.no_grad():
153                true_iou = torch.stack([
154                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
155                ])
156            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
157            # the object axis is back in the first dimension.
158            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])
159
160            mask_loss = mask_loss + torch.mean(dice_scores)
161            iou_regression_loss = iou_regression_loss + iou_score
162
163        loss = mask_loss + iou_regression_loss
164
165        return loss, mask_loss, iou_regression_loss
166
167    #
168    # Functionality for iterative prompting loss
169    #
170
171    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
172        # Batched mask and logit (low-res mask) predictions.
173        masks = torch.stack([m["masks"] for m in batched_outputs])
174        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
175
176        # Determine the best IOU across the multi-object prediction axis
177        # and turn this into a mask we can use for indexing.
178        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
179        # for details on the indexing logic.
180        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
181        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()
182
183        # Index the mask and logits with the best iou indices.
184        # Note that we squash the first two axes (batch x objects) into one when indexing.
185        # That's why we need to reshape bax into (batch x objects) using a view.
186        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
187        batch_size, n_objects = masks.shape[:2]
188        h, w = masks.shape[-2:]
189        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
190
191        h, w = logits.shape[-2:]
192        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
193
194        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
195        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
196        masks = (masks > 0.0).float()
197        return masks, logits
198
199    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
200        """Compute the loss for several (sub-)iterations of iterative prompting.
201        In each iterations the prompts are updated based on the previous predictions.
202        """
203        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
204
205        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
206
207        for i in range(0, num_subiter):
208            # We do multimasking only in the first sub-iteration as we then pass single prompt
209            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
210            batched_outputs = self.model(batched_inputs,
211                                         image_embeddings=image_embeddings,
212                                         multimask_output=multimask_output if i == 0 else False)
213
214            # Compute loss for tis sub-iteration.
215            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
216
217            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
218            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
219            with torch.no_grad():
220                net_mean_model_iou = torch.mean(batched_iou_predictions)
221
222            loss += net_loss
223            mask_loss += net_mask_loss
224            iou_regression_loss += net_iou_regression_loss
225            mean_model_iou += net_mean_model_iou
226
227            if i < (num_subiter - 1):   # We need not update the prompts for the last iteration.
228                # Determine the next prompts based on current predictions.
229                with torch.no_grad():
230                    # Get the mask and logit predictions corresponding to the predicted object
231                    # (per actual object) with the best IOU.
232                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
233                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)
234
235        loss = loss / num_subiter
236        mask_loss = mask_loss / num_subiter
237        iou_regression_loss = iou_regression_loss / num_subiter
238        mean_model_iou = mean_model_iou / num_subiter
239
240        return loss, mask_loss, iou_regression_loss, mean_model_iou
241
242    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
243        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
244        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
245            # here, we get each object in the pairs and do the point choices per-object
246            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
247
248            # convert the point coordinates to the expected resolution for iterative prompting
249            # NOTE:
250            #   - "only" need to transform the point prompts from the iterative prompting
251            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
252            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
253
254            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
255                if "point_coords" in _inp.keys() else net_coords
256            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
257                if "point_labels" in _inp.keys() else net_labels
258
259            _inp["point_coords"] = updated_point_coords
260            _inp["point_labels"] = updated_point_labels
261
262            if self.mask_prob > 0:
263                # using mask inputs for iterative prompting while training, with a probability
264                use_mask_inputs = (random.random() < self.mask_prob)
265                if use_mask_inputs:
266                    _inp["mask_inputs"] = logits
267                else:  # remove  previously existing mask inputs to avoid using them in next sub-iteration
268                    _inp.pop("mask_inputs", None)
269
270        return batched_inputs
271
272    #
273    # Training Loop
274    #
275
276    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
277        """Compute one hot target (one mask per channel) for the sampled ids
278        and restrict the number of sampled objects to the minimal number in the batch.
279        """
280        assert len(y) == len(sampled_ids)
281
282        # Get the minimal number of objects in this batch.
283        # The number of objects in a patch might be < n_objects_per_batch.
284        # This is why we need to restrict it here to ensure the same
285        # number of objects across the batch.
286        n_objects = min(len(ids) for ids in sampled_ids)
287
288        y = y.to(self.device, non_blocking=True)
289        # Compute the one hot targets for the seg-id.
290        y_one_hot = torch.stack([
291            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
292            for target, ids in zip(y, sampled_ids)
293        ]).float()
294
295        # Also restrict the prompts to the number of objects.
296        batched_inputs = [
297            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
298            for inp in batched_inputs
299        ]
300        return batched_inputs, y_one_hot
301
302    def _interactive_train_iteration(self, x, y):
303        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
304
305        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
306        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
307
308        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
309            batched_inputs, y_one_hot,
310            num_subiter=self.n_sub_iteration, multimask_output=multimask_output
311        )
312        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
313
314    def _check_input_normalization(self, x, input_check_done):
315        # The expected data range of the SAM model is 8bit (0-255).
316        # It can easily happen that data is normalized beforehand in training.
317        # For some reasons we don't fully understand this still works, but it
318        # should still be avoided and is very detrimental in some settings
319        # (e.g. when freezing the image encoder)
320        # We check once per epoch if the data seems to be normalized already and
321        # raise a warning if this is the case.
322        if not input_check_done:
323            data_min, data_max = x.min(), x.max()
324            if (data_min < 0) or (data_max < 1):
325                warnings.warn(
326                    "It looks like you are normalizing the training data."
327                    "The SAM model takes care of normalization, so it is better to not do this."
328                    "We recommend to remove data normalization and input data in the range [0, 255]."
329                )
330            input_check_done = True
331
332        return input_check_done
333
334    def _train_epoch_impl(self, progress, forward_context, backprop):
335        self.model.train()
336
337        input_check_done = False
338
339        n_iter = 0
340        t_per_iter = time.time()
341        for x, y in self.train_loader:
342            input_check_done = self._check_input_normalization(x, input_check_done)
343
344            self.optimizer.zero_grad()
345
346            with forward_context():
347                (loss, mask_loss, iou_regression_loss, model_iou,
348                 sampled_binary_y) = self._interactive_train_iteration(x, y)
349
350            backprop(loss)
351
352            if self.logger is not None:
353                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
354                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
355                self.logger.log_train(self._iteration, loss, lr, x, y, samples,
356                                      mask_loss, iou_regression_loss, model_iou)
357
358            self._iteration += 1
359            n_iter += 1
360            if self._iteration >= self.max_iteration:
361                break
362            progress.update(1)
363
364        t_per_iter = (time.time() - t_per_iter) / n_iter
365        return t_per_iter
366
367    def _interactive_val_iteration(self, x, y, val_iteration):
368        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
369
370        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
371        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
372
373        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
374
375        batched_outputs = self.model(
376            batched_inputs,
377            image_embeddings=image_embeddings,
378            multimask_output=multimask_output,
379        )
380
381        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
382        # We use the dice loss over the masks as metric.
383        metric = mask_loss
384        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
385
386        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
387
388    def _validate_impl(self, forward_context):
389        self.model.eval()
390
391        input_check_done = False
392
393        val_iteration = 0
394        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
395
396        with torch.no_grad():
397            for x, y in self.val_loader:
398                input_check_done = self._check_input_normalization(x, input_check_done)
399
400                with forward_context():
401                    (loss, mask_loss, iou_regression_loss, model_iou,
402                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
403
404                loss_val += loss.item()
405                metric_val += metric.item()
406                model_iou_val += model_iou.item()
407                val_iteration += 1
408
409        loss_val /= len(self.val_loader)
410        metric_val /= len(self.val_loader)
411        model_iou_val /= len(self.val_loader)
412        print()
413        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")
414
415        if self.logger is not None:
416            self.logger.log_validation(
417                self._iteration, metric_val, loss_val, x, y,
418                sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
419            )
420
421        return metric_val

Trainer class for training the Segment Anything model.

This class is derived from torch_em.trainer.DefaultTrainer. Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py for details on its usage and implementation.

Arguments:
  • convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSamInputs can be used here.
  • n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. In each sub-iteration new point prompts are sampled where the model was wrong.
  • n_objects_per_batch: If not given, we compute the loss for all objects in a sample. Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
  • mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
  • prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
  • mask_prob: The probability of using the mask inputs in the iterative prompting (per n_sub_iteration).
  • mask_loss: The loss to compare the predicted masks and the targets.
  • kwargs: The keyword arguments of the DefaultTrainer super class.
SamTrainer( convert_inputs, n_sub_iteration: int, n_objects_per_batch: Optional[int] = None, mse_loss: torch.nn.modules.module.Module = MSELoss(), prompt_generator: micro_sam.prompt_generators.PromptGeneratorBase = <micro_sam.prompt_generators.IterativePromptGenerator object>, mask_prob: float = 0.5, mask_loss: Optional[torch.nn.modules.module.Module] = None, **kwargs)
40    def __init__(
41        self,
42        convert_inputs,
43        n_sub_iteration: int,
44        n_objects_per_batch: Optional[int] = None,
45        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
46        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
47        mask_prob: float = 0.5,
48        mask_loss: Optional[torch.nn.Module] = None,
49        **kwargs
50    ):
51        if mask_loss is None:
52            # We have to use the Dice Loss with reduce channel set to None.
53            # Hence we hard-code it here to avoid issues by passsing wrong options for the loss.
54            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
55        else:
56            self.mask_loss = mask_loss
57
58        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
59        self.convert_inputs = convert_inputs
60        self.mse_loss = mse_loss
61        self.n_objects_per_batch = n_objects_per_batch
62        self.n_sub_iteration = n_sub_iteration
63        self.prompt_generator = prompt_generator
64        self.mask_prob = mask_prob
65        self._kwargs = kwargs
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
Inherited Members
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
from_checkpoint
Serializer
save_checkpoint
load_checkpoint
fit