micro_sam.training.sam_trainer

  1import os
  2import time
  3import random
  4import warnings
  5from typing import Optional, Callable
  6
  7import numpy as np
  8
  9import torch
 10from torchvision.utils import make_grid
 11
 12import torch_em
 13from torch_em.trainer.logger_base import TorchEmLogger
 14
 15from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator
 16
 17
 18class SamTrainer(torch_em.trainer.DefaultTrainer):
 19    """Trainer class for training the Segment Anything model.
 20
 21    This class is derived from `torch_em.trainer.DefaultTrainer`.
 22    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 23    for details on its usage and implementation.
 24
 25    Args:
 26        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 27            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
 28        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
 29            In each sub-iteration new point prompts are sampled where the model was wrong.
 30        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
 31            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
 32        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
 33            By default, set to the expected mse loss function.
 34        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
 35            Already allocated with the desired prompt generator by default.
 36        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
 37            By default, set to '0.5'.
 38        mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function.
 39        kwargs: The keyword arguments of the `DefaultTrainer` super class.
 40    """
 41
 42    def __init__(
 43        self,
 44        convert_inputs: Callable,
 45        n_sub_iteration: int,
 46        n_objects_per_batch: Optional[int] = None,
 47        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
 48        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
 49        mask_prob: float = 0.5,
 50        mask_loss: Optional[torch.nn.Module] = None,
 51        **kwargs
 52    ):
 53        if mask_loss is None:
 54            # We have to use the Dice Loss with reduce channel set to None.
 55            # Hence we hard-code it here to avoid issues by passsing wrong options for the loss.
 56            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
 57        else:
 58            self.mask_loss = mask_loss
 59
 60        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
 61        self.convert_inputs = convert_inputs
 62        self.mse_loss = mse_loss
 63        self.n_objects_per_batch = n_objects_per_batch
 64        self.n_sub_iteration = n_sub_iteration
 65        self.prompt_generator = prompt_generator
 66        self.mask_prob = mask_prob
 67        self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized()
 68        self._kwargs = kwargs
 69
 70    def _get_prompt_and_multimasking_choices(self, current_iteration):
 71        """Choose the type of prompts we sample for training, and then we call
 72        'convert_inputs' with the correct prompting from here.
 73        """
 74        if current_iteration % 2 == 0:  # sample only a single point per object
 75            n_pos, n_neg = 1, 0
 76            get_boxes = False
 77            multimask_output = True
 78
 79        else:  # sample only a single box per object
 80            n_pos, n_neg = 0, 0
 81            get_boxes = True
 82            multimask_output = False
 83
 84        return n_pos, n_neg, get_boxes, multimask_output
 85
 86    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
 87        """Choose the type of prompts we sample for validation, and then we call
 88        'convert_inputs' with the correct prompting from here.
 89        """
 90        if current_iteration % 4 == 0:  # sample only a single point per object
 91            n_pos, n_neg = 1, 0
 92            get_boxes = False
 93            multimask_output = True
 94
 95        elif current_iteration % 4 == 1:  # sample only a single box per object
 96            n_pos, n_neg = 0, 0
 97            get_boxes = True
 98            multimask_output = False
 99
100        elif current_iteration % 4 == 2:  # sample a random no. of points
101            pos_range, neg_range = 4, 4
102
103            n_pos = np.random.randint(1, pos_range + 1)
104            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
105                n_neg = np.random.randint(1, neg_range + 1)
106            else:
107                n_neg = np.random.randint(0, neg_range + 1)
108            get_boxes = False
109            multimask_output = False
110
111        else:  # sample boxes AND random no. of points
112            # here we can have (1, 0) because we also have box
113            pos_range, neg_range = 4, 4
114
115            n_pos = np.random.randint(1, pos_range + 1)
116            n_neg = np.random.randint(0, neg_range + 1)
117            get_boxes = True
118            multimask_output = False
119
120        return n_pos, n_neg, get_boxes, multimask_output
121
122    def _compute_iou(self, pred, true, eps=1e-7):
123        """Compute the IoU score between the prediction and target.
124        """
125        pred_mask = pred > 0.5  # binarizing the output predictions
126        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
127        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
128        iou = overlap / (union + eps)
129        return iou
130
131    def _compute_loss(self, batched_outputs, y_one_hot):
132        """Compute the loss for one iteration. The loss is made up of two components:
133        - The mask loss: dice score between the predicted masks and targets.
134        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
135        """
136        mask_loss, iou_regression_loss = 0.0, 0.0
137        batch_size = len(batched_outputs)
138
139        # Loop over the batch.
140        for batch_output, targets in zip(batched_outputs, y_one_hot):
141
142            predicted_objects = torch.sigmoid(batch_output["masks"])
143            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
144            # We swap the axes that go into the dice loss so that the object axis
145            # corresponds to the channel axes. This ensures that the dice is computed
146            # independetly per channel. We do not reduce the channel axis in the dice,
147            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
148            dice_scores = torch.stack([
149                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
150                for i in range(predicted_objects.shape[1])
151            ])
152            dice_scores, _ = torch.min(dice_scores, dim=0)
153
154            # Compute the actual IOU between the predicted and true objects.
155            # The outer loop is for the 1 or 3 predicted masks per true object.
156            with torch.no_grad():
157                true_iou = torch.stack([
158                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
159                ])
160            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
161            # the object axis is back in the first dimension.
162            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])
163
164            mask_loss = mask_loss + torch.mean(dice_scores)
165            iou_regression_loss = iou_regression_loss + iou_score
166
167        # Normalize by batch size so that loss/metric are comparable across batch sizes.
168        mask_loss = mask_loss / batch_size
169        iou_regression_loss = iou_regression_loss / batch_size
170        loss = mask_loss + iou_regression_loss
171
172        return loss, mask_loss, iou_regression_loss
173
174    #
175    # Functionality for iterative prompting loss
176    #
177
178    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
179        # Batched mask and logit (low-res mask) predictions.
180        masks = torch.stack([m["masks"] for m in batched_outputs])
181        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
182
183        # Determine the best IOU across the multi-object prediction axis
184        # and turn this into a mask we can use for indexing.
185        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
186        # for details on the indexing logic.
187        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
188        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()
189
190        # Index the mask and logits with the best iou indices.
191        # Note that we squash the first two axes (batch x objects) into one when indexing.
192        # That's why we need to reshape bax into (batch x objects) using a view.
193        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
194        batch_size, n_objects = masks.shape[:2]
195        h, w = masks.shape[-2:]
196        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
197
198        h, w = logits.shape[-2:]
199        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
200
201        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
202        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
203        masks = (masks > 0.0).float()
204        return masks, logits
205
206    def _use_mask_inputs(self, batched_inputs, y_one_hot):
207        # Whether to use masks per training top-iteration.
208        use_mask_inputs = False  # determines if each sub-iteration will use mask inputs as prompts or not.
209        use_zero_mask = False  # determines if the zeroth iteration will use zeros as mask inputs.
210
211        if self.mask_prob == 1:  # i.e. always use masks.
212            use_mask_inputs = True  # we would like to use mask inputs in all sub-iterations.
213            use_zero_mask = self.is_data_parallel  # we would like to use zeros as mask inputs for zeroth iteration.
214
215        elif self.mask_prob > 0:  # i.e. if we use mask inputs with a probability.
216            if self.is_data_parallel:  # if training on multiple GPUs.
217                if torch.distributed.get_rank() == 0:  # device with rank 0.
218                    use_mask_inputs_tensor = torch.tensor(
219                        random.random() < self.mask_prob, dtype=torch.uint8, device=self.device,
220                    )
221                else:  # on other devices, we do not need this parameter at this stage.
222                    use_mask_inputs_tensor = torch.tensor(0, dtype=torch.uint8, device=self.device)
223
224                # Broadcast the value to all devices (ranks).
225                torch.distributed.broadcast(use_mask_inputs_tensor, src=0)
226
227                # And convert it back to our desired boolean value.
228                use_mask_inputs = bool(use_mask_inputs_tensor.item())
229                use_zero_mask = use_mask_inputs  # provides zeros as mask inputs.
230            else:  # training on a single GPU.
231                use_mask_inputs = None
232
233        if use_zero_mask:
234            # We use zeros as mask inputs for the zeroth iteration.
235            y_zeros = torch.zeros((*y_one_hot.shape[:3], 256, 256))
236
237            # Add zeros as mask inputs to batched inputs.
238            for bi, curr_masks in zip(batched_inputs, y_zeros):
239                bi["mask_inputs"] = curr_masks
240
241        return batched_inputs, use_mask_inputs
242
243    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
244        """Compute the loss for several (sub-)iterations of iterative prompting.
245        In each iterations the prompts are updated based on the previous predictions.
246        """
247        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
248
249        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
250
251        # Whether to use mask inputs in each sub-iteration.
252        batched_inputs, use_mask_inputs = self._use_mask_inputs(batched_inputs, y_one_hot)
253
254        for i in range(0, num_subiter):
255            # We do multimasking only in the first sub-iteration as we then pass single prompt
256            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
257            batched_outputs = self.model(
258                batched_inputs=batched_inputs,
259                image_embeddings=image_embeddings,
260                multimask_output=multimask_output if i == 0 else False,
261            )
262
263            # Compute loss for this sub-iteration.
264            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
265
266            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
267            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
268            with torch.no_grad():
269                net_mean_model_iou = torch.mean(batched_iou_predictions)
270
271            loss = loss + net_loss
272            mask_loss = mask_loss + net_mask_loss
273            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
274            mean_model_iou = mean_model_iou + net_mean_model_iou
275
276            if i < (num_subiter - 1):   # We need not update the prompts for the last iteration.
277                # Determine the next prompts based on current predictions.
278                with torch.no_grad():
279                    # Get the mask and logit predictions corresponding to the predicted object
280                    # (per actual object) with the best IOU.
281                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
282                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits, use_mask_inputs)
283
284        loss = loss / num_subiter
285        mask_loss = mask_loss / num_subiter
286        iou_regression_loss = iou_regression_loss / num_subiter
287        mean_model_iou = mean_model_iou / num_subiter
288
289        return loss, mask_loss, iou_regression_loss, mean_model_iou
290
291    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks, use_mask_inputs):
292        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
293        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
294            # here, we get each object in the pairs and do the point choices per-object
295            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
296
297            # convert the point coordinates to the expected resolution for iterative prompting
298            # NOTE:
299            #   - "only" need to transform the point prompts from the iterative prompting
300            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
301            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
302
303            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
304                if "point_coords" in _inp.keys() else net_coords
305            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
306                if "point_labels" in _inp.keys() else net_labels
307
308            _inp["point_coords"] = updated_point_coords
309            _inp["point_labels"] = updated_point_labels
310
311            if self.is_data_parallel:  # multi-GPU training
312                use_mask_inputs_this_iter = use_mask_inputs
313            else:  # single GPU training
314                if self.mask_prob > 0:
315                    # using mask inputs for iterative prompting while training, with a probability
316                    use_mask_inputs = (random.random() < self.mask_prob)
317                else:  # otherwise we assume it is 0 and do not need the generator to decide.
318                    use_mask_inputs = False
319
320                use_mask_inputs_this_iter = use_mask_inputs
321
322            if use_mask_inputs_this_iter:
323                _inp["mask_inputs"] = logits
324            else:  # remove previously existing mask inputs to avoid using them in next sub-iteration.
325                _inp.pop("mask_inputs", None)
326
327        return batched_inputs
328
329    #
330    # Training Loop
331    #
332
333    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
334        """Compute one hot target (one mask per channel) for the sampled ids
335        and restrict the number of sampled objects to the minimal number in the batch.
336        """
337        assert len(y) == len(sampled_ids)
338
339        # Get the minimal number of objects in this batch.
340        # The number of objects in a patch might be < n_objects_per_batch.
341        # This is why we need to restrict it here to ensure the same
342        # number of objects across the batch.
343        n_objects = min(len(ids) for ids in sampled_ids)
344
345        y = y.to(self.device, non_blocking=True)
346        # Compute the one hot targets for the seg-id.
347        y_one_hot = torch.stack([
348            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
349            for target, ids in zip(y, sampled_ids)
350        ]).float()
351
352        # Also restrict the prompts to the number of objects.
353        batched_inputs = [
354            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
355            for inp in batched_inputs
356        ]
357        return batched_inputs, y_one_hot
358
359    def _interactive_train_iteration(self, x, y):
360        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
361
362        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
363        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
364
365        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
366            batched_inputs=batched_inputs,
367            y_one_hot=y_one_hot,
368            num_subiter=self.n_sub_iteration,
369            multimask_output=multimask_output
370        )
371        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
372
373    def _check_input_normalization(self, x, input_check_done):
374        # The expected data range of the SAM model is 8bit (0-255).
375        # It can easily happen that data is normalized beforehand in training.
376        # For some reasons we don't fully understand this still works, but it
377        # should still be avoided and is very detrimental in some settings
378        # (e.g. when freezing the image encoder)
379        # We check once per epoch if the data seems to be normalized already and
380        # raise a warning if this is the case.
381        if not input_check_done:
382            data_min, data_max = x.min(), x.max()
383            if (data_min < 0) or (data_max < 1):
384                warnings.warn(
385                    "It looks like you are normalizing the training data. "
386                    "The SAM model takes care of normalization, so it is better to not do this. "
387                    "We recommend to remove data normalization and input data in the range [0, 255]."
388                )
389            input_check_done = True
390
391        return input_check_done
392
393    def _train_epoch_impl(self, progress, forward_context, backprop):
394        self.model.train()
395
396        input_check_done = False
397
398        n_iter = 0
399        t_per_iter = time.time()
400        for x, y in self.train_loader:
401            input_check_done = self._check_input_normalization(x, input_check_done)
402
403            self.optimizer.zero_grad()
404
405            with forward_context():
406                (loss, mask_loss, iou_regression_loss, model_iou,
407                 sampled_binary_y) = self._interactive_train_iteration(x, y)
408
409            backprop(loss)
410
411            if self.logger is not None:
412                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
413                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
414                self.logger.log_train(
415                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
416                )
417
418            self._iteration += 1
419            n_iter += 1
420            if self._iteration >= self.max_iteration:
421                break
422            progress.update(1)
423
424        t_per_iter = (time.time() - t_per_iter) / n_iter
425        return t_per_iter
426
427    def _interactive_val_iteration(self, x, y, val_iteration):
428        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
429
430        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
431        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
432
433        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
434
435        batched_outputs = self.model(
436            batched_inputs=batched_inputs,
437            image_embeddings=image_embeddings,
438            multimask_output=multimask_output,
439        )
440
441        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
442        # We use the dice loss over the masks as metric.
443        metric = mask_loss
444        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
445
446        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
447
448    def _validate_impl(self, forward_context):
449        self.model.eval()
450
451        input_check_done = False
452
453        val_iteration = 0
454        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
455        mask_loss_val, iou_loss_val = 0.0, 0.0
456
457        with torch.no_grad():
458            for x, y in self.val_loader:
459                input_check_done = self._check_input_normalization(x, input_check_done)
460
461                with forward_context():
462                    (loss, mask_loss, iou_regression_loss, model_iou,
463                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
464
465                loss_val += loss.item()
466                metric_val += metric.item()
467                mask_loss_val += mask_loss.item()
468                iou_loss_val += iou_regression_loss.item()
469                model_iou_val += model_iou.item()
470                val_iteration += 1
471
472        loss_val /= len(self.val_loader)
473        metric_val /= len(self.val_loader)
474        mask_loss_val /= len(self.val_loader)
475        iou_loss_val /= len(self.val_loader)
476        model_iou_val /= len(self.val_loader)
477        print()
478        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")
479
480        if self.logger is not None:
481            self.logger.log_validation(
482                self._iteration, metric_val, loss_val, x, y,
483                sampled_binary_y, mask_loss_val, iou_loss_val, model_iou_val
484            )
485
486        return metric_val
487
488
489class SamLogger(TorchEmLogger):
490    """@private"""
491    def __init__(self, trainer, save_root, **unused_kwargs):
492        super().__init__(trainer, save_root)
493        self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
494        os.makedirs(self.log_dir, exist_ok=True)
495
496        self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
497        self.log_image_interval = trainer.log_image_interval
498
499    def add_image(self, x, y, samples, name, step):
500        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
501        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
502        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
503        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)
504
505    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
506        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
507        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
508        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
509        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
510        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
511        if step % self.log_image_interval == 0:
512            self.add_image(x, y, samples, "train", step)
513
514    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
515        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
516        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
517        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
518        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
519        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
520        self.add_image(x, y, samples, "validation", step)
class SamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
 19class SamTrainer(torch_em.trainer.DefaultTrainer):
 20    """Trainer class for training the Segment Anything model.
 21
 22    This class is derived from `torch_em.trainer.DefaultTrainer`.
 23    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 24    for details on its usage and implementation.
 25
 26    Args:
 27        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 28            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
 29        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
 30            In each sub-iteration new point prompts are sampled where the model was wrong.
 31        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
 32            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
 33        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
 34            By default, set to the expected mse loss function.
 35        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
 36            Already allocated with the desired prompt generator by default.
 37        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
 38            By default, set to '0.5'.
 39        mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function.
 40        kwargs: The keyword arguments of the `DefaultTrainer` super class.
 41    """
 42
 43    def __init__(
 44        self,
 45        convert_inputs: Callable,
 46        n_sub_iteration: int,
 47        n_objects_per_batch: Optional[int] = None,
 48        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
 49        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
 50        mask_prob: float = 0.5,
 51        mask_loss: Optional[torch.nn.Module] = None,
 52        **kwargs
 53    ):
 54        if mask_loss is None:
 55            # We have to use the Dice Loss with reduce channel set to None.
 56            # Hence we hard-code it here to avoid issues by passsing wrong options for the loss.
 57            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
 58        else:
 59            self.mask_loss = mask_loss
 60
 61        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
 62        self.convert_inputs = convert_inputs
 63        self.mse_loss = mse_loss
 64        self.n_objects_per_batch = n_objects_per_batch
 65        self.n_sub_iteration = n_sub_iteration
 66        self.prompt_generator = prompt_generator
 67        self.mask_prob = mask_prob
 68        self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized()
 69        self._kwargs = kwargs
 70
 71    def _get_prompt_and_multimasking_choices(self, current_iteration):
 72        """Choose the type of prompts we sample for training, and then we call
 73        'convert_inputs' with the correct prompting from here.
 74        """
 75        if current_iteration % 2 == 0:  # sample only a single point per object
 76            n_pos, n_neg = 1, 0
 77            get_boxes = False
 78            multimask_output = True
 79
 80        else:  # sample only a single box per object
 81            n_pos, n_neg = 0, 0
 82            get_boxes = True
 83            multimask_output = False
 84
 85        return n_pos, n_neg, get_boxes, multimask_output
 86
 87    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
 88        """Choose the type of prompts we sample for validation, and then we call
 89        'convert_inputs' with the correct prompting from here.
 90        """
 91        if current_iteration % 4 == 0:  # sample only a single point per object
 92            n_pos, n_neg = 1, 0
 93            get_boxes = False
 94            multimask_output = True
 95
 96        elif current_iteration % 4 == 1:  # sample only a single box per object
 97            n_pos, n_neg = 0, 0
 98            get_boxes = True
 99            multimask_output = False
100
101        elif current_iteration % 4 == 2:  # sample a random no. of points
102            pos_range, neg_range = 4, 4
103
104            n_pos = np.random.randint(1, pos_range + 1)
105            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
106                n_neg = np.random.randint(1, neg_range + 1)
107            else:
108                n_neg = np.random.randint(0, neg_range + 1)
109            get_boxes = False
110            multimask_output = False
111
112        else:  # sample boxes AND random no. of points
113            # here we can have (1, 0) because we also have box
114            pos_range, neg_range = 4, 4
115
116            n_pos = np.random.randint(1, pos_range + 1)
117            n_neg = np.random.randint(0, neg_range + 1)
118            get_boxes = True
119            multimask_output = False
120
121        return n_pos, n_neg, get_boxes, multimask_output
122
123    def _compute_iou(self, pred, true, eps=1e-7):
124        """Compute the IoU score between the prediction and target.
125        """
126        pred_mask = pred > 0.5  # binarizing the output predictions
127        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
128        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
129        iou = overlap / (union + eps)
130        return iou
131
132    def _compute_loss(self, batched_outputs, y_one_hot):
133        """Compute the loss for one iteration. The loss is made up of two components:
134        - The mask loss: dice score between the predicted masks and targets.
135        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
136        """
137        mask_loss, iou_regression_loss = 0.0, 0.0
138        batch_size = len(batched_outputs)
139
140        # Loop over the batch.
141        for batch_output, targets in zip(batched_outputs, y_one_hot):
142
143            predicted_objects = torch.sigmoid(batch_output["masks"])
144            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
145            # We swap the axes that go into the dice loss so that the object axis
146            # corresponds to the channel axes. This ensures that the dice is computed
147            # independetly per channel. We do not reduce the channel axis in the dice,
148            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
149            dice_scores = torch.stack([
150                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
151                for i in range(predicted_objects.shape[1])
152            ])
153            dice_scores, _ = torch.min(dice_scores, dim=0)
154
155            # Compute the actual IOU between the predicted and true objects.
156            # The outer loop is for the 1 or 3 predicted masks per true object.
157            with torch.no_grad():
158                true_iou = torch.stack([
159                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
160                ])
161            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
162            # the object axis is back in the first dimension.
163            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])
164
165            mask_loss = mask_loss + torch.mean(dice_scores)
166            iou_regression_loss = iou_regression_loss + iou_score
167
168        # Normalize by batch size so that loss/metric are comparable across batch sizes.
169        mask_loss = mask_loss / batch_size
170        iou_regression_loss = iou_regression_loss / batch_size
171        loss = mask_loss + iou_regression_loss
172
173        return loss, mask_loss, iou_regression_loss
174
175    #
176    # Functionality for iterative prompting loss
177    #
178
179    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
180        # Batched mask and logit (low-res mask) predictions.
181        masks = torch.stack([m["masks"] for m in batched_outputs])
182        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
183
184        # Determine the best IOU across the multi-object prediction axis
185        # and turn this into a mask we can use for indexing.
186        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
187        # for details on the indexing logic.
188        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
189        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()
190
191        # Index the mask and logits with the best iou indices.
192        # Note that we squash the first two axes (batch x objects) into one when indexing.
193        # That's why we need to reshape bax into (batch x objects) using a view.
194        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
195        batch_size, n_objects = masks.shape[:2]
196        h, w = masks.shape[-2:]
197        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
198
199        h, w = logits.shape[-2:]
200        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
201
202        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
203        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
204        masks = (masks > 0.0).float()
205        return masks, logits
206
207    def _use_mask_inputs(self, batched_inputs, y_one_hot):
208        # Whether to use masks per training top-iteration.
209        use_mask_inputs = False  # determines if each sub-iteration will use mask inputs as prompts or not.
210        use_zero_mask = False  # determines if the zeroth iteration will use zeros as mask inputs.
211
212        if self.mask_prob == 1:  # i.e. always use masks.
213            use_mask_inputs = True  # we would like to use mask inputs in all sub-iterations.
214            use_zero_mask = self.is_data_parallel  # we would like to use zeros as mask inputs for zeroth iteration.
215
216        elif self.mask_prob > 0:  # i.e. if we use mask inputs with a probability.
217            if self.is_data_parallel:  # if training on multiple GPUs.
218                if torch.distributed.get_rank() == 0:  # device with rank 0.
219                    use_mask_inputs_tensor = torch.tensor(
220                        random.random() < self.mask_prob, dtype=torch.uint8, device=self.device,
221                    )
222                else:  # on other devices, we do not need this parameter at this stage.
223                    use_mask_inputs_tensor = torch.tensor(0, dtype=torch.uint8, device=self.device)
224
225                # Broadcast the value to all devices (ranks).
226                torch.distributed.broadcast(use_mask_inputs_tensor, src=0)
227
228                # And convert it back to our desired boolean value.
229                use_mask_inputs = bool(use_mask_inputs_tensor.item())
230                use_zero_mask = use_mask_inputs  # provides zeros as mask inputs.
231            else:  # training on a single GPU.
232                use_mask_inputs = None
233
234        if use_zero_mask:
235            # We use zeros as mask inputs for the zeroth iteration.
236            y_zeros = torch.zeros((*y_one_hot.shape[:3], 256, 256))
237
238            # Add zeros as mask inputs to batched inputs.
239            for bi, curr_masks in zip(batched_inputs, y_zeros):
240                bi["mask_inputs"] = curr_masks
241
242        return batched_inputs, use_mask_inputs
243
244    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
245        """Compute the loss for several (sub-)iterations of iterative prompting.
246        In each iterations the prompts are updated based on the previous predictions.
247        """
248        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
249
250        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
251
252        # Whether to use mask inputs in each sub-iteration.
253        batched_inputs, use_mask_inputs = self._use_mask_inputs(batched_inputs, y_one_hot)
254
255        for i in range(0, num_subiter):
256            # We do multimasking only in the first sub-iteration as we then pass single prompt
257            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
258            batched_outputs = self.model(
259                batched_inputs=batched_inputs,
260                image_embeddings=image_embeddings,
261                multimask_output=multimask_output if i == 0 else False,
262            )
263
264            # Compute loss for this sub-iteration.
265            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
266
267            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
268            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
269            with torch.no_grad():
270                net_mean_model_iou = torch.mean(batched_iou_predictions)
271
272            loss = loss + net_loss
273            mask_loss = mask_loss + net_mask_loss
274            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
275            mean_model_iou = mean_model_iou + net_mean_model_iou
276
277            if i < (num_subiter - 1):   # We need not update the prompts for the last iteration.
278                # Determine the next prompts based on current predictions.
279                with torch.no_grad():
280                    # Get the mask and logit predictions corresponding to the predicted object
281                    # (per actual object) with the best IOU.
282                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
283                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits, use_mask_inputs)
284
285        loss = loss / num_subiter
286        mask_loss = mask_loss / num_subiter
287        iou_regression_loss = iou_regression_loss / num_subiter
288        mean_model_iou = mean_model_iou / num_subiter
289
290        return loss, mask_loss, iou_regression_loss, mean_model_iou
291
292    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks, use_mask_inputs):
293        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
294        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
295            # here, we get each object in the pairs and do the point choices per-object
296            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
297
298            # convert the point coordinates to the expected resolution for iterative prompting
299            # NOTE:
300            #   - "only" need to transform the point prompts from the iterative prompting
301            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
302            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
303
304            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
305                if "point_coords" in _inp.keys() else net_coords
306            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
307                if "point_labels" in _inp.keys() else net_labels
308
309            _inp["point_coords"] = updated_point_coords
310            _inp["point_labels"] = updated_point_labels
311
312            if self.is_data_parallel:  # multi-GPU training
313                use_mask_inputs_this_iter = use_mask_inputs
314            else:  # single GPU training
315                if self.mask_prob > 0:
316                    # using mask inputs for iterative prompting while training, with a probability
317                    use_mask_inputs = (random.random() < self.mask_prob)
318                else:  # otherwise we assume it is 0 and do not need the generator to decide.
319                    use_mask_inputs = False
320
321                use_mask_inputs_this_iter = use_mask_inputs
322
323            if use_mask_inputs_this_iter:
324                _inp["mask_inputs"] = logits
325            else:  # remove previously existing mask inputs to avoid using them in next sub-iteration.
326                _inp.pop("mask_inputs", None)
327
328        return batched_inputs
329
330    #
331    # Training Loop
332    #
333
334    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
335        """Compute one hot target (one mask per channel) for the sampled ids
336        and restrict the number of sampled objects to the minimal number in the batch.
337        """
338        assert len(y) == len(sampled_ids)
339
340        # Get the minimal number of objects in this batch.
341        # The number of objects in a patch might be < n_objects_per_batch.
342        # This is why we need to restrict it here to ensure the same
343        # number of objects across the batch.
344        n_objects = min(len(ids) for ids in sampled_ids)
345
346        y = y.to(self.device, non_blocking=True)
347        # Compute the one hot targets for the seg-id.
348        y_one_hot = torch.stack([
349            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
350            for target, ids in zip(y, sampled_ids)
351        ]).float()
352
353        # Also restrict the prompts to the number of objects.
354        batched_inputs = [
355            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
356            for inp in batched_inputs
357        ]
358        return batched_inputs, y_one_hot
359
360    def _interactive_train_iteration(self, x, y):
361        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
362
363        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
364        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
365
366        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
367            batched_inputs=batched_inputs,
368            y_one_hot=y_one_hot,
369            num_subiter=self.n_sub_iteration,
370            multimask_output=multimask_output
371        )
372        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
373
374    def _check_input_normalization(self, x, input_check_done):
375        # The expected data range of the SAM model is 8bit (0-255).
376        # It can easily happen that data is normalized beforehand in training.
377        # For some reasons we don't fully understand this still works, but it
378        # should still be avoided and is very detrimental in some settings
379        # (e.g. when freezing the image encoder)
380        # We check once per epoch if the data seems to be normalized already and
381        # raise a warning if this is the case.
382        if not input_check_done:
383            data_min, data_max = x.min(), x.max()
384            if (data_min < 0) or (data_max < 1):
385                warnings.warn(
386                    "It looks like you are normalizing the training data. "
387                    "The SAM model takes care of normalization, so it is better to not do this. "
388                    "We recommend to remove data normalization and input data in the range [0, 255]."
389                )
390            input_check_done = True
391
392        return input_check_done
393
394    def _train_epoch_impl(self, progress, forward_context, backprop):
395        self.model.train()
396
397        input_check_done = False
398
399        n_iter = 0
400        t_per_iter = time.time()
401        for x, y in self.train_loader:
402            input_check_done = self._check_input_normalization(x, input_check_done)
403
404            self.optimizer.zero_grad()
405
406            with forward_context():
407                (loss, mask_loss, iou_regression_loss, model_iou,
408                 sampled_binary_y) = self._interactive_train_iteration(x, y)
409
410            backprop(loss)
411
412            if self.logger is not None:
413                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
414                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
415                self.logger.log_train(
416                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
417                )
418
419            self._iteration += 1
420            n_iter += 1
421            if self._iteration >= self.max_iteration:
422                break
423            progress.update(1)
424
425        t_per_iter = (time.time() - t_per_iter) / n_iter
426        return t_per_iter
427
428    def _interactive_val_iteration(self, x, y, val_iteration):
429        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
430
431        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
432        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
433
434        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
435
436        batched_outputs = self.model(
437            batched_inputs=batched_inputs,
438            image_embeddings=image_embeddings,
439            multimask_output=multimask_output,
440        )
441
442        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
443        # We use the dice loss over the masks as metric.
444        metric = mask_loss
445        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
446
447        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
448
449    def _validate_impl(self, forward_context):
450        self.model.eval()
451
452        input_check_done = False
453
454        val_iteration = 0
455        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
456        mask_loss_val, iou_loss_val = 0.0, 0.0
457
458        with torch.no_grad():
459            for x, y in self.val_loader:
460                input_check_done = self._check_input_normalization(x, input_check_done)
461
462                with forward_context():
463                    (loss, mask_loss, iou_regression_loss, model_iou,
464                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
465
466                loss_val += loss.item()
467                metric_val += metric.item()
468                mask_loss_val += mask_loss.item()
469                iou_loss_val += iou_regression_loss.item()
470                model_iou_val += model_iou.item()
471                val_iteration += 1
472
473        loss_val /= len(self.val_loader)
474        metric_val /= len(self.val_loader)
475        mask_loss_val /= len(self.val_loader)
476        iou_loss_val /= len(self.val_loader)
477        model_iou_val /= len(self.val_loader)
478        print()
479        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")
480
481        if self.logger is not None:
482            self.logger.log_validation(
483                self._iteration, metric_val, loss_val, x, y,
484                sampled_binary_y, mask_loss_val, iou_loss_val, model_iou_val
485            )
486
487        return metric_val

Trainer class for training the Segment Anything model.

This class is derived from torch_em.trainer.DefaultTrainer. Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py for details on its usage and implementation.

Arguments:
  • convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSamInputs can be used here.
  • n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. In each sub-iteration new point prompts are sampled where the model was wrong.
  • n_objects_per_batch: If not given, we compute the loss for all objects in a sample. Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
  • mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. By default, torch.nn.MSELoss().
  • prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training. By default, an IterativePromptGenerator instance.
  • mask_prob: The probability of using the mask inputs in the iterative prompting (per n_sub_iteration). By default, 0.5.
  • mask_loss: The loss to compare the predicted masks and the targets. By default, the Dice loss torch_em.loss.DiceLoss(reduce_channel=None).
  • kwargs: The keyword arguments of the DefaultTrainer super class.
SamTrainer( convert_inputs: Callable, n_sub_iteration: int, n_objects_per_batch: Optional[int] = None, mse_loss: torch.nn.modules.module.Module = MSELoss(), prompt_generator: micro_sam.prompt_generators.PromptGeneratorBase = <micro_sam.prompt_generators.IterativePromptGenerator object>, mask_prob: float = 0.5, mask_loss: Optional[torch.nn.modules.module.Module] = None, **kwargs)
    def __init__(
        self,
        convert_inputs: Callable,
        n_sub_iteration: int,
        n_objects_per_batch: Optional[int] = None,
        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
        mask_prob: float = 0.5,
        mask_loss: Optional[torch.nn.Module] = None,
        **kwargs
    ):
        # NOTE: the `mse_loss` and `prompt_generator` defaults are created once, when this function
        # is defined. They are therefore shared by every trainer that does not pass its own instance.
        if mask_loss is None:
            # We have to use the Dice Loss with reduce channel set to None, so that the loss
            # computation gets one dice value per object (channel).
            # Hence we hard-code it here to avoid issues by passing wrong options for the loss.
            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
        else:
            self.mask_loss = mask_loss

        # The mask loss is used as both the `loss` and the `metric` of the base trainer.
        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
        self.convert_inputs = convert_inputs
        self.mse_loss = mse_loss
        self.n_objects_per_batch = n_objects_per_batch
        self.n_sub_iteration = n_sub_iteration
        self.prompt_generator = prompt_generator
        self.mask_prob = mask_prob
        # Multi-GPU training is detected through an initialized torch.distributed process group.
        self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized()
        # Keep the keyword arguments that were forwarded to the `DefaultTrainer` constructor.
        self._kwargs = kwargs
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
is_data_parallel
Inherited Members
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit