micro_sam.training.sam_trainer

  1import os
  2import time
  3import random
  4import warnings
  5from typing import Optional
  6
  7import numpy as np
  8
  9import torch
 10from torchvision.utils import make_grid
 11
 12import torch_em
 13from torch_em.trainer.logger_base import TorchEmLogger
 14
 15from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator
 16
 17
 18class SamTrainer(torch_em.trainer.DefaultTrainer):
 19    """Trainer class for training the Segment Anything model.
 20
 21    This class is derived from `torch_em.trainer.DefaultTrainer`.
 22    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 23    for details on its usage and implementation.
 24
 25    Args:
 26        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 27            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
 28        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
 29            In each sub-iteration new point prompts are sampled where the model was wrong.
 30        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
 31            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
 32        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
 33        prompt_generator: The iterative prompt generator, which takes care of the iterative prompting logic during training.
 34        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
 35        mask_loss: The loss to compare the predicted masks and the targets.
 36        kwargs: The keyword arguments of the DefaultTrainer super class.
 37    """
 38
 39    def __init__(
 40        self,
 41        convert_inputs,
 42        n_sub_iteration: int,
 43        n_objects_per_batch: Optional[int] = None,
 44        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
 45        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
 46        mask_prob: float = 0.5,
 47        mask_loss: Optional[torch.nn.Module] = None,
 48        **kwargs
 49    ):
 50        if mask_loss is None:
 51            # We have to use the Dice Loss with reduce_channel set to None.
 52            # Hence we hard-code it here to avoid issues from passing wrong options for the loss.
 53            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
 54        else:
 55            self.mask_loss = mask_loss
 56
 57        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
 58        self.convert_inputs = convert_inputs
 59        self.mse_loss = mse_loss
 60        self.n_objects_per_batch = n_objects_per_batch
 61        self.n_sub_iteration = n_sub_iteration
 62        self.prompt_generator = prompt_generator
 63        self.mask_prob = mask_prob
 64        self._kwargs = kwargs
 65
 66    def _get_prompt_and_multimasking_choices(self, current_iteration):
 67        """Choose the type of prompts we sample for training; 'convert_inputs' is then
 68        called with the corresponding prompt settings.
 69        """
 70        if current_iteration % 2 == 0:  # sample only a single point per object
 71            n_pos, n_neg = 1, 0
 72            get_boxes = False
 73            multimask_output = True
 74
 75        else:  # sample only a single box per object
 76            n_pos, n_neg = 0, 0
 77            get_boxes = True
 78            multimask_output = False
 79
 80        return n_pos, n_neg, get_boxes, multimask_output
 81
 82    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
 83        """Choose the type of prompts we sample for validation; 'convert_inputs' is then
 84        called with the corresponding prompt settings.
 85        """
 86        if current_iteration % 4 == 0:  # sample only a single point per object
 87            n_pos, n_neg = 1, 0
 88            get_boxes = False
 89            multimask_output = True
 90
 91        elif current_iteration % 4 == 1:  # sample only a single box per object
 92            n_pos, n_neg = 0, 0
 93            get_boxes = True
 94            multimask_output = False
 95
 96        elif current_iteration % 4 == 2:  # sample a random no. of points
 97            pos_range, neg_range = 4, 4
 98
 99            n_pos = np.random.randint(1, pos_range + 1)
100            if n_pos == 1:  # avoid the (1, 0) combination, which is already covered above, but still allow (n_pos, 0)
101                n_neg = np.random.randint(1, neg_range + 1)
102            else:
103                n_neg = np.random.randint(0, neg_range + 1)
104            get_boxes = False
105            multimask_output = False
106
107        else:  # sample boxes AND random no. of points
108            # here we can have (1, 0) because we also have box
109            pos_range, neg_range = 4, 4
110
111            n_pos = np.random.randint(1, pos_range + 1)
112            n_neg = np.random.randint(0, neg_range + 1)
113            get_boxes = True
114            multimask_output = False
115
116        return n_pos, n_neg, get_boxes, multimask_output
117
118    def _compute_iou(self, pred, true, eps=1e-7):
119        """Compute the IoU score between the prediction and target.
120        """
121        pred_mask = pred > 0.5  # binarizing the output predictions
122        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
123        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
124        iou = overlap / (union + eps)
125        return iou
126
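A minimal sketch of the IoU computation above, using hypothetical toy tensors; the (objects, 1, H, W) layout mirrors how `_compute_loss` passes one predicted mask channel at a time:

import torch

# Two objects, one mask channel, 4x4 masks (hypothetical shapes and values).
pred = torch.zeros(2, 1, 4, 4)
true = torch.zeros(2, 1, 4, 4)
pred[0, 0, :2, :2] = 0.9  # predicted foreground probabilities for object 0
true[0, 0, :2, :3] = 1.0  # ground-truth mask for object 0; object 1 stays empty

pred_mask = pred > 0.5  # binarize, as in _compute_iou
overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
iou = overlap / (union + 1e-7)  # ~tensor([0.6667, 0.0000]): overlap 4 / union 6 for object 0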
127    def _compute_loss(self, batched_outputs, y_one_hot):
128        """Compute the loss for one iteration. The loss is made up of two components:
129        - The mask loss: dice score between the predicted masks and targets.
130        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
131        """
132        mask_loss, iou_regression_loss = 0.0, 0.0
133
134        # Loop over the batch.
135        for batch_output, targets in zip(batched_outputs, y_one_hot):
136
137            predicted_objects = torch.sigmoid(batch_output["masks"])
138            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
139            # We swap the axes that go into the dice loss so that the object axis
140            # corresponds to the channel axis. This ensures that the dice is computed
141            # independently per channel. We do not reduce the channel axis in the dice,
142            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
143            dice_scores = torch.stack([
144                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
145                for i in range(predicted_objects.shape[1])
146            ])
147            dice_scores, _ = torch.min(dice_scores, dim=0)
148
149            # Compute the actual IOU between the predicted and true objects.
150            # The outer loop is for the 1 or 3 predicted masks per true object.
151            with torch.no_grad():
152                true_iou = torch.stack([
153                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
154                ])
155            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
156            # the object axis is back in the first dimension.
157            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])
158
159            mask_loss = mask_loss + torch.mean(dice_scores)
160            iou_regression_loss = iou_regression_loss + iou_score
161
162        loss = mask_loss + iou_regression_loss
163
164        return loss, mask_loss, iou_regression_loss
165
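To make the minimum-over-candidates logic above concrete, here is a hedged sketch with a stand-in dice loss; the real code uses `torch_em.loss.DiceLoss(reduce_channel=None)`, and all shapes and values below are hypothetical:

import torch

def toy_dice_loss(pred, target, eps=1e-7):
    # Stand-in for a per-channel dice loss: returns one score per channel (no channel reduction).
    dims = (0, 2, 3)
    numerator = 2 * (pred * target).sum(dim=dims)
    denominator = pred.sum(dim=dims) + target.sum(dim=dims)
    return 1.0 - numerator / (denominator + eps)

# 2 objects, 3 multimask candidates per object, 8x8 masks.
predicted_objects = torch.rand(2, 3, 8, 8)
targets = torch.randint(0, 2, (2, 1, 8, 8)).float()

dice_scores = torch.stack([
    toy_dice_loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
    for i in range(predicted_objects.shape[1])
])  # shape (3, 2): one dice loss per candidate and object
best_scores, _ = torch.min(dice_scores, dim=0)  # keep the best (lowest) loss per object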
166    #
167    # Functionality for iterative prompting loss
168    #
169
170    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
171        # Batched mask and logit (low-res mask) predictions.
172        masks = torch.stack([m["masks"] for m in batched_outputs])
173        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
174
175        # Determine the best IOU across the multi-object prediction axis
176        # and turn this into a mask we can use for indexing.
177        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
178        # for details on the indexing logic.
179        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
180        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()
181
182        # Index the mask and logits with the best iou indices.
183        # Note that we squash the first two axes (batch x objects) into one when indexing.
184        # That's why we need to reshape back into (batch x objects) using a view.
185        # We also keep the multi-object axis as a singleton, which is why the view has shape (batch_size, n_objects, 1, ...).
186        batch_size, n_objects = masks.shape[:2]
187        h, w = masks.shape[-2:]
188        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
189
190        h, w = logits.shape[-2:]
191        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
192
193        # Binarize the masks. Note that the masks here are still logits, so we use 0.0
194        # as the threshold instead of 0.5. (Hence we don't need to apply a sigmoid.)
195        masks = (masks > 0.0).float()
196        return masks, logits
197
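The argmax-to-boolean-mask indexing above (see the linked Stack Overflow answer) is easier to follow on a small example with hypothetical shapes:

import torch

# One image, 2 objects, 3 multimask candidates, 4x4 masks.
iou_predictions = torch.tensor([[[0.2, 0.9, 0.5],
                                 [0.7, 0.1, 0.3]]])  # shape (1, 2, 3)
masks = torch.rand(1, 2, 3, 4, 4)

# Turn the per-object argmax into a boolean mask over the candidate axis.
best_iou_idx = torch.argmax(iou_predictions, dim=2, keepdim=True)
best_iou_idx = torch.zeros_like(iou_predictions).scatter(2, best_iou_idx, value=1).bool()

# Boolean indexing flattens (batch, objects, candidates); the view restores (batch, objects, 1, ...).
batch_size, n_objects = masks.shape[:2]
h, w = masks.shape[-2:]
best_masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
# best_masks[0, 0] is the candidate with predicted IOU 0.9, best_masks[0, 1] the one with 0.7.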
198    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
199        """Compute the loss for several (sub-)iterations of iterative prompting.
200        In each iteration the prompts are updated based on the previous predictions.
201        """
202        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
203
204        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
205
206        for i in range(0, num_subiter):
207            # We do multimasking only in the first sub-iteration, where only a single prompt is passed per object.
208            # After the first sub-iteration we don't do multimasking, because multiple prompts have accumulated.
209            batched_outputs = self.model(
210                batched_inputs=batched_inputs,
211                image_embeddings=image_embeddings,
212                multimask_output=multimask_output if i == 0 else False,
213            )
214
215            # Compute the loss for this sub-iteration.
216            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
217
218            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
219            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
220            with torch.no_grad():
221                net_mean_model_iou = torch.mean(batched_iou_predictions)
222
223            loss = loss + net_loss
224            mask_loss = mask_loss + net_mask_loss
225            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
226            mean_model_iou = mean_model_iou + net_mean_model_iou
227
228            if i < (num_subiter - 1):   # We need not update the prompts for the last iteration.
229                # Determine the next prompts based on current predictions.
230                with torch.no_grad():
231                    # Get the mask and logit predictions corresponding to the predicted object
232                    # (per actual object) with the best IOU.
233                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
234                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits)
235
236        loss = loss / num_subiter
237        mask_loss = mask_loss / num_subiter
238        iou_regression_loss = iou_regression_loss / num_subiter
239        mean_model_iou = mean_model_iou / num_subiter
240
241        return loss, mask_loss, iou_regression_loss, mean_model_iou
242
243    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks):
244        # Loop over the batch: get the predicted and true masks per sample (and the corresponding "batched_inputs").
245        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
246            # Sample new point prompts per object, based on where the prediction and the target disagree.
247            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
248
249            # convert the point coordinates to the expected resolution for iterative prompting
250            # NOTE:
251            #   - we only need to transform the point prompts from the iterative prompting
252            #   - the `logits` are the low-res masks (256, 256) and hence do not need the transform
253            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
254
255            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
256                if "point_coords" in _inp.keys() else net_coords
257            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
258                if "point_labels" in _inp.keys() else net_labels
259
260            _inp["point_coords"] = updated_point_coords
261            _inp["point_labels"] = updated_point_labels
262
263            if self.mask_prob > 0:
264                # Use the mask inputs for iterative prompting during training, with probability mask_prob.
265                use_mask_inputs = (random.random() < self.mask_prob)
266                if use_mask_inputs:
267                    _inp["mask_inputs"] = logits
268                else:  # remove previously existing mask inputs to avoid using them in the next sub-iteration
269                    _inp.pop("mask_inputs", None)
270
271        return batched_inputs
272
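Because of the concatenation along dim=1 above, point prompts accumulate over the sub-iterations; a tiny sketch with hypothetical coordinates:

import torch

# Two objects with one initial point prompt each.
point_coords = torch.tensor([[[10.0, 12.0]], [[40.0, 35.0]]])  # shape (n_objects=2, 1, 2)
net_coords = torch.tensor([[[11.0, 20.0]], [[42.0, 30.0]]])    # newly sampled points, same shape

updated = torch.cat([point_coords, net_coords], dim=1)  # shape (2, 2, 2): one more point per object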
273    #
274    # Training Loop
275    #
276
277    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
278        """Compute one-hot targets (one mask per channel) for the sampled ids
279        and restrict the number of sampled objects to the minimal number in the batch.
280        """
281        assert len(y) == len(sampled_ids)
282
283        # Get the minimal number of objects in this batch.
284        # The number of objects in a patch might be < n_objects_per_batch.
285        # This is why we need to restrict it here to ensure the same
286        # number of objects across the batch.
287        n_objects = min(len(ids) for ids in sampled_ids)
288
289        y = y.to(self.device, non_blocking=True)
290        # Compute the one-hot targets for the sampled segmentation ids.
291        y_one_hot = torch.stack([
292            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
293            for target, ids in zip(y, sampled_ids)
294        ]).float()
295
296        # Also restrict the prompts to the number of objects.
297        batched_inputs = [
298            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
299            for inp in batched_inputs
300        ]
301        return batched_inputs, y_one_hot
302
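For reference, a toy example (hypothetical labels and sampled ids) of the one-hot target construction used above:

import torch

# A single 4x4 instance segmentation with objects 1 and 2, and the ids sampled for it.
y = torch.tensor([[[[0, 1, 1, 0],
                    [0, 1, 1, 0],
                    [2, 2, 0, 0],
                    [2, 2, 0, 0]]]])  # shape (batch=1, 1, 4, 4)
sampled_ids = [[1, 2]]

n_objects = min(len(ids) for ids in sampled_ids)
y_one_hot = torch.stack([
    torch.stack([target == seg_id for seg_id in ids[:n_objects]])
    for target, ids in zip(y, sampled_ids)
]).float()  # shape (1, 2, 1, 4, 4): one binary mask channel per sampled object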
303    def _interactive_train_iteration(self, x, y):
304        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
305
306        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
307        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
308
309        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
310            batched_inputs=batched_inputs,
311            y_one_hot=y_one_hot,
312            num_subiter=self.n_sub_iteration,
313            multimask_output=multimask_output
314        )
315        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
316
317    def _check_input_normalization(self, x, input_check_done):
318        # The expected data range of the SAM model is 8-bit (0-255).
319        # It can easily happen that data is normalized beforehand in training.
320        # For reasons we don't fully understand, this still works, but it
321        # should be avoided, since it is very detrimental in some settings
322        # (e.g. when freezing the image encoder).
323        # We check once per epoch if the data seems to be normalized already and
324        # raise a warning if this is the case.
325        if not input_check_done:
326            data_min, data_max = x.min(), x.max()
327            if (data_min < 0) or (data_max < 1):
328                warnings.warn(
329                    "It looks like you are normalizing the training data. "
330                    "The SAM model takes care of normalization, so it is better not to do this. "
331                    "We recommend removing the data normalization and passing input data in the range [0, 255]."
332                )
333            input_check_done = True
334
335        return input_check_done
336
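To illustrate what the check above catches, a quick sketch with hypothetical tensors: raw 8-bit style data passes, while standardized data (which contains negative values) would trigger the warning:

import torch

raw = torch.randint(0, 256, (1, 1, 64, 64)).float()  # values in [0, 255]: passes the check
standardized = (raw - raw.mean()) / raw.std()         # min < 0: this would trigger the warning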
337    def _train_epoch_impl(self, progress, forward_context, backprop):
338        self.model.train()
339
340        input_check_done = False
341
342        n_iter = 0
343        t_per_iter = time.time()
344        for x, y in self.train_loader:
345            input_check_done = self._check_input_normalization(x, input_check_done)
346
347            self.optimizer.zero_grad()
348
349            with forward_context():
350                (loss, mask_loss, iou_regression_loss, model_iou,
351                 sampled_binary_y) = self._interactive_train_iteration(x, y)
352
353            backprop(loss)
354
355            if self.logger is not None:
356                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
357                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
358                self.logger.log_train(
359                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
360                )
361
362            self._iteration += 1
363            n_iter += 1
364            if self._iteration >= self.max_iteration:
365                break
366            progress.update(1)
367
368        t_per_iter = (time.time() - t_per_iter) / n_iter
369        return t_per_iter
370
371    def _interactive_val_iteration(self, x, y, val_iteration):
372        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
373
374        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
375        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
376
377        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
378
379        batched_outputs = self.model(
380            batched_inputs=batched_inputs,
381            image_embeddings=image_embeddings,
382            multimask_output=multimask_output,
383        )
384
385        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
386        # We use the dice loss over the masks as metric.
387        metric = mask_loss
388        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
389
390        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
391
392    def _validate_impl(self, forward_context):
393        self.model.eval()
394
395        input_check_done = False
396
397        val_iteration = 0
398        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
399
400        with torch.no_grad():
401            for x, y in self.val_loader:
402                input_check_done = self._check_input_normalization(x, input_check_done)
403
404                with forward_context():
405                    (loss, mask_loss, iou_regression_loss, model_iou,
406                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
407
408                loss_val += loss.item()
409                metric_val += metric.item()
410                model_iou_val += model_iou.item()
411                val_iteration += 1
412
413        loss_val /= len(self.val_loader)
414        metric_val /= len(self.val_loader)
415        model_iou_val /= len(self.val_loader)
416        print()
417        print(f"The average dice score for the current epoch is {1 - metric_val}")
418
419        if self.logger is not None:
420            self.logger.log_validation(
421                self._iteration, metric_val, loss_val, x, y,
422                sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
423            )
424
425        return metric_val
426
427
428class SamLogger(TorchEmLogger):
429    """@private"""
430    def __init__(self, trainer, save_root, **unused_kwargs):
431        super().__init__(trainer, save_root)
432        self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
433        os.makedirs(self.log_dir, exist_ok=True)
434
435        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
436        self.log_image_interval = trainer.log_image_interval
437
438    def add_image(self, x, y, samples, name, step):
439        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
440        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
441        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
442        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)
443
444    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
445        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
446        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
447        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
448        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
449        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
450        if step % self.log_image_interval == 0:
451            self.add_image(x, y, samples, "train", step)
452
453    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
454        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
455        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
456        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
457        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
458        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
459        self.add_image(x, y, samples, "validation", step)
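For orientation, a hedged end-to-end usage sketch. It assumes a trainable SAM wrapper that exposes `image_embeddings_oft` and a `transform` attribute (in micro_sam such a model is typically created with a helper like `micro_sam.training.get_trainable_sam_model`; treat that name as an assumption here), plus torch_em-style dataloaders that yield images in the range [0, 255] together with instance labels. `model`, `train_loader` and `val_loader` are placeholders, and the `ConvertToSamInputs` arguments shown are assumptions rather than a documented signature:

import torch

from micro_sam.training.sam_trainer import SamTrainer, SamLogger
from micro_sam.training.util import ConvertToSamInputs

# Placeholders: a trainable SAM model wrapper and torch_em-style dataloaders.
# model, train_loader, val_loader = ...

trainer = SamTrainer(
    name="sam_finetuning",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=torch.optim.AdamW(model.parameters(), lr=1e-5),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    convert_inputs=ConvertToSamInputs(transform=model.transform),  # constructor arguments assumed
    n_sub_iteration=8,
    n_objects_per_batch=25,
    logger=SamLogger,  # logger keyword inherited from the DefaultTrainer
)
trainer.fit(iterations=10_000)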