micro_sam.training.sam_trainer

  1import os
  2import time
  3import random
  4import warnings
  5from typing import Optional, Callable
  6
  7import numpy as np
  8
  9import torch
 10from torchvision.utils import make_grid
 11
 12import torch_em
 13from torch_em.trainer.logger_base import TorchEmLogger
 14
 15from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator
 16
 17
 18class SamTrainer(torch_em.trainer.DefaultTrainer):
 19    """Trainer class for training the Segment Anything model.
 20
 21    This class is derived from `torch_em.trainer.DefaultTrainer`.
 22    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 23    for details on its usage and implementation.
 24
 25    Args:
 26        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 27            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
 28        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
 29            In each sub-iteration new point prompts are sampled where the model was wrong.
 30        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
 31            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
 32        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
 33            By default, set to the expected mse loss function.
 34        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
 35            Already allocated with the desired prompt generator by default.
 36        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
 37            By default, set to '0.5'.
 38        mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function.
 39        kwargs: The keyword arguments of the `DefaultTrainer` super class.
 40    """
 41
 42    def __init__(
 43        self,
 44        convert_inputs: Callable,
 45        n_sub_iteration: int,
 46        n_objects_per_batch: Optional[int] = None,
 47        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
 48        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
 49        mask_prob: float = 0.5,
 50        mask_loss: Optional[torch.nn.Module] = None,
 51        **kwargs
 52    ):
 53        if mask_loss is None:
 54            # We have to use the Dice Loss with reduce channel set to None.
 55            # Hence we hard-code it here to avoid issues by passsing wrong options for the loss.
 56            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
 57        else:
 58            self.mask_loss = mask_loss
 59
 60        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
 61        self.convert_inputs = convert_inputs
 62        self.mse_loss = mse_loss
 63        self.n_objects_per_batch = n_objects_per_batch
 64        self.n_sub_iteration = n_sub_iteration
 65        self.prompt_generator = prompt_generator
 66        self.mask_prob = mask_prob
 67        self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized()
 68        self._kwargs = kwargs
 69
 70    def _get_prompt_and_multimasking_choices(self, current_iteration):
 71        """Choose the type of prompts we sample for training, and then we call
 72        'convert_inputs' with the correct prompting from here.
 73        """
 74        if current_iteration % 2 == 0:  # sample only a single point per object
 75            n_pos, n_neg = 1, 0
 76            get_boxes = False
 77            multimask_output = True
 78
 79        else:  # sample only a single box per object
 80            n_pos, n_neg = 0, 0
 81            get_boxes = True
 82            multimask_output = False
 83
 84        return n_pos, n_neg, get_boxes, multimask_output
 85
 86    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
 87        """Choose the type of prompts we sample for validation, and then we call
 88        'convert_inputs' with the correct prompting from here.
 89        """
 90        if current_iteration % 4 == 0:  # sample only a single point per object
 91            n_pos, n_neg = 1, 0
 92            get_boxes = False
 93            multimask_output = True
 94
 95        elif current_iteration % 4 == 1:  # sample only a single box per object
 96            n_pos, n_neg = 0, 0
 97            get_boxes = True
 98            multimask_output = False
 99
100        elif current_iteration % 4 == 2:  # sample a random no. of points
101            pos_range, neg_range = 4, 4
102
103            n_pos = np.random.randint(1, pos_range + 1)
104            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
105                n_neg = np.random.randint(1, neg_range + 1)
106            else:
107                n_neg = np.random.randint(0, neg_range + 1)
108            get_boxes = False
109            multimask_output = False
110
111        else:  # sample boxes AND random no. of points
112            # here we can have (1, 0) because we also have box
113            pos_range, neg_range = 4, 4
114
115            n_pos = np.random.randint(1, pos_range + 1)
116            n_neg = np.random.randint(0, neg_range + 1)
117            get_boxes = True
118            multimask_output = False
119
120        return n_pos, n_neg, get_boxes, multimask_output
121
122    def _compute_iou(self, pred, true, eps=1e-7):
123        """Compute the IoU score between the prediction and target.
124        """
125        pred_mask = pred > 0.5  # binarizing the output predictions
126        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
127        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
128        iou = overlap / (union + eps)
129        return iou
130
131    def _compute_loss(self, batched_outputs, y_one_hot):
132        """Compute the loss for one iteration. The loss is made up of two components:
133        - The mask loss: dice score between the predicted masks and targets.
134        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
135        """
136        mask_loss, iou_regression_loss = 0.0, 0.0
137
138        # Loop over the batch.
139        for batch_output, targets in zip(batched_outputs, y_one_hot):
140
141            predicted_objects = torch.sigmoid(batch_output["masks"])
142            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
143            # We swap the axes that go into the dice loss so that the object axis
144            # corresponds to the channel axes. This ensures that the dice is computed
145            # independetly per channel. We do not reduce the channel axis in the dice,
146            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
147            dice_scores = torch.stack([
148                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
149                for i in range(predicted_objects.shape[1])
150            ])
151            dice_scores, _ = torch.min(dice_scores, dim=0)
152
153            # Compute the actual IOU between the predicted and true objects.
154            # The outer loop is for the 1 or 3 predicted masks per true object.
155            with torch.no_grad():
156                true_iou = torch.stack([
157                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
158                ])
159            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
160            # the object axis is back in the first dimension.
161            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])
162
163            mask_loss = mask_loss + torch.mean(dice_scores)
164            iou_regression_loss = iou_regression_loss + iou_score
165
166        loss = mask_loss + iou_regression_loss
167
168        return loss, mask_loss, iou_regression_loss
169
170    #
171    # Functionality for iterative prompting loss
172    #
173
174    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
175        # Batched mask and logit (low-res mask) predictions.
176        masks = torch.stack([m["masks"] for m in batched_outputs])
177        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
178
179        # Determine the best IOU across the multi-object prediction axis
180        # and turn this into a mask we can use for indexing.
181        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
182        # for details on the indexing logic.
183        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
184        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()
185
186        # Index the mask and logits with the best iou indices.
187        # Note that we squash the first two axes (batch x objects) into one when indexing.
188        # That's why we need to reshape bax into (batch x objects) using a view.
189        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
190        batch_size, n_objects = masks.shape[:2]
191        h, w = masks.shape[-2:]
192        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
193
194        h, w = logits.shape[-2:]
195        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
196
197        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
198        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
199        masks = (masks > 0.0).float()
200        return masks, logits
201
202    def _use_mask_inputs(self, batched_inputs, y_one_hot):
203        # Whether to use masks per training top-iteration.
204        use_mask_inputs = False  # determines if each sub-iteration will use mask inputs as prompts or not.
205        use_zero_mask = False  # determines if the zeroth iteration will use zeros as mask inputs.
206
207        if self.mask_prob == 1:  # i.e. always use masks.
208            use_mask_inputs = True  # we would like to use mask inputs in all sub-iterations.
209            use_zero_mask = self.is_data_parallel  # we would like to use zeros as mask inputs for zeroth iteration.
210
211        elif self.mask_prob > 0:  # i.e. if we use mask inputs with a probability.
212            if self.is_data_parallel:  # if training on multiple GPUs.
213                if torch.distributed.get_rank() == 0:  # device with rank 0.
214                    use_mask_inputs_tensor = torch.tensor(
215                        random.random() < self.mask_prob, dtype=torch.uint8, device=self.device,
216                    )
217                else:  # on other devices, we do not need this parameter at this stage.
218                    use_mask_inputs_tensor = torch.tensor(0, dtype=torch.uint8, device=self.device)
219
220                # Broadcast the value to all devices (ranks).
221                torch.distributed.broadcast(use_mask_inputs_tensor, src=0)
222
223                # And convert it back to our desired boolean value.
224                use_mask_inputs = bool(use_mask_inputs_tensor.item())
225                use_zero_mask = use_mask_inputs  # provides zeros as mask inputs.
226            else:  # training on a single GPU.
227                use_mask_inputs = None
228
229        if use_zero_mask:
230            # We use zeros as mask inputs for the zeroth iteration.
231            y_zeros = torch.zeros((*y_one_hot.shape[:3], 256, 256))
232
233            # Add zeros as mask inputs to batched inputs.
234            for bi, curr_masks in zip(batched_inputs, y_zeros):
235                bi["mask_inputs"] = curr_masks
236
237        return batched_inputs, use_mask_inputs
238
239    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
240        """Compute the loss for several (sub-)iterations of iterative prompting.
241        In each iterations the prompts are updated based on the previous predictions.
242        """
243        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
244
245        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
246
247        # Whether to use mask inputs in each sub-iteration.
248        batched_inputs, use_mask_inputs = self._use_mask_inputs(batched_inputs, y_one_hot)
249
250        for i in range(0, num_subiter):
251            # We do multimasking only in the first sub-iteration as we then pass single prompt
252            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
253            batched_outputs = self.model(
254                batched_inputs=batched_inputs,
255                image_embeddings=image_embeddings,
256                multimask_output=multimask_output if i == 0 else False,
257            )
258
259            # Compute loss for this sub-iteration.
260            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
261
262            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
263            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
264            with torch.no_grad():
265                net_mean_model_iou = torch.mean(batched_iou_predictions)
266
267            loss = loss + net_loss
268            mask_loss = mask_loss + net_mask_loss
269            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
270            mean_model_iou = mean_model_iou + net_mean_model_iou
271
272            if i < (num_subiter - 1):   # We need not update the prompts for the last iteration.
273                # Determine the next prompts based on current predictions.
274                with torch.no_grad():
275                    # Get the mask and logit predictions corresponding to the predicted object
276                    # (per actual object) with the best IOU.
277                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
278                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits, use_mask_inputs)
279
280        loss = loss / num_subiter
281        mask_loss = mask_loss / num_subiter
282        iou_regression_loss = iou_regression_loss / num_subiter
283        mean_model_iou = mean_model_iou / num_subiter
284
285        return loss, mask_loss, iou_regression_loss, mean_model_iou
286
287    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks, use_mask_inputs):
288        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
289        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
290            # here, we get each object in the pairs and do the point choices per-object
291            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
292
293            # convert the point coordinates to the expected resolution for iterative prompting
294            # NOTE:
295            #   - "only" need to transform the point prompts from the iterative prompting
296            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
297            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
298
299            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
300                if "point_coords" in _inp.keys() else net_coords
301            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
302                if "point_labels" in _inp.keys() else net_labels
303
304            _inp["point_coords"] = updated_point_coords
305            _inp["point_labels"] = updated_point_labels
306
307            if self.is_data_parallel:  # multi-GPU training
308                use_mask_inputs_this_iter = use_mask_inputs
309            else:  # single GPU training
310                if self.mask_prob > 0:
311                    # using mask inputs for iterative prompting while training, with a probability
312                    use_mask_inputs = (random.random() < self.mask_prob)
313                else:  # otherwise we assume it is 0 and do not need the generator to decide.
314                    use_mask_inputs = False
315
316                use_mask_inputs_this_iter = use_mask_inputs
317
318            if use_mask_inputs_this_iter:
319                _inp["mask_inputs"] = logits
320            else:  # remove previously existing mask inputs to avoid using them in next sub-iteration.
321                _inp.pop("mask_inputs", None)
322
323        return batched_inputs
324
325    #
326    # Training Loop
327    #
328
329    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
330        """Compute one hot target (one mask per channel) for the sampled ids
331        and restrict the number of sampled objects to the minimal number in the batch.
332        """
333        assert len(y) == len(sampled_ids)
334
335        # Get the minimal number of objects in this batch.
336        # The number of objects in a patch might be < n_objects_per_batch.
337        # This is why we need to restrict it here to ensure the same
338        # number of objects across the batch.
339        n_objects = min(len(ids) for ids in sampled_ids)
340
341        y = y.to(self.device, non_blocking=True)
342        # Compute the one hot targets for the seg-id.
343        y_one_hot = torch.stack([
344            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
345            for target, ids in zip(y, sampled_ids)
346        ]).float()
347
348        # Also restrict the prompts to the number of objects.
349        batched_inputs = [
350            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
351            for inp in batched_inputs
352        ]
353        return batched_inputs, y_one_hot
354
355    def _interactive_train_iteration(self, x, y):
356        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
357
358        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
359        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
360
361        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
362            batched_inputs=batched_inputs,
363            y_one_hot=y_one_hot,
364            num_subiter=self.n_sub_iteration,
365            multimask_output=multimask_output
366        )
367        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
368
369    def _check_input_normalization(self, x, input_check_done):
370        # The expected data range of the SAM model is 8bit (0-255).
371        # It can easily happen that data is normalized beforehand in training.
372        # For some reasons we don't fully understand this still works, but it
373        # should still be avoided and is very detrimental in some settings
374        # (e.g. when freezing the image encoder)
375        # We check once per epoch if the data seems to be normalized already and
376        # raise a warning if this is the case.
377        if not input_check_done:
378            data_min, data_max = x.min(), x.max()
379            if (data_min < 0) or (data_max < 1):
380                warnings.warn(
381                    "It looks like you are normalizing the training data. "
382                    "The SAM model takes care of normalization, so it is better to not do this. "
383                    "We recommend to remove data normalization and input data in the range [0, 255]."
384                )
385            input_check_done = True
386
387        return input_check_done
388
389    def _train_epoch_impl(self, progress, forward_context, backprop):
390        self.model.train()
391
392        input_check_done = False
393
394        n_iter = 0
395        t_per_iter = time.time()
396        for x, y in self.train_loader:
397            input_check_done = self._check_input_normalization(x, input_check_done)
398
399            self.optimizer.zero_grad()
400
401            with forward_context():
402                (loss, mask_loss, iou_regression_loss, model_iou,
403                 sampled_binary_y) = self._interactive_train_iteration(x, y)
404
405            backprop(loss)
406
407            if self.logger is not None:
408                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
409                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
410                self.logger.log_train(
411                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
412                )
413
414            self._iteration += 1
415            n_iter += 1
416            if self._iteration >= self.max_iteration:
417                break
418            progress.update(1)
419
420        t_per_iter = (time.time() - t_per_iter) / n_iter
421        return t_per_iter
422
423    def _interactive_val_iteration(self, x, y, val_iteration):
424        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
425
426        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
427        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
428
429        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
430
431        batched_outputs = self.model(
432            batched_inputs=batched_inputs,
433            image_embeddings=image_embeddings,
434            multimask_output=multimask_output,
435        )
436
437        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
438        # We use the dice loss over the masks as metric.
439        metric = mask_loss
440        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
441
442        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
443
444    def _validate_impl(self, forward_context):
445        self.model.eval()
446
447        input_check_done = False
448
449        val_iteration = 0
450        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
451
452        with torch.no_grad():
453            for x, y in self.val_loader:
454                input_check_done = self._check_input_normalization(x, input_check_done)
455
456                with forward_context():
457                    (loss, mask_loss, iou_regression_loss, model_iou,
458                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
459
460                loss_val += loss.item()
461                metric_val += metric.item()
462                model_iou_val += model_iou.item()
463                val_iteration += 1
464
465        loss_val /= len(self.val_loader)
466        metric_val /= len(self.val_loader)
467        model_iou_val /= len(self.val_loader)
468        print()
469        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")
470
471        if self.logger is not None:
472            self.logger.log_validation(
473                self._iteration, metric_val, loss_val, x, y,
474                sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
475            )
476
477        return metric_val
478
479
class SamLogger(TorchEmLogger):
    """@private"""
    # TensorBoard logger used with `SamTrainer`. It writes:
    # - the loss components and the mean IoU predicted by the model;
    # - the learning rate (training only) and the metric (validation only);
    # - example input / target / sample images.

    def __init__(self, trainer, save_root, **unused_kwargs):
        # Import explicitly: `import torch` alone does not load the `torch.utils.tensorboard`
        # submodule, so attribute access on it only works if another module imported it first.
        from torch.utils.tensorboard import SummaryWriter

        super().__init__(trainer, save_root)
        # Logs go to ./logs/<trainer name>, or <save_root>/logs/<trainer name> if save_root is given.
        self.log_dir = f"./logs/{trainer.name}" if save_root is None else os.path.join(save_root, "logs", trainer.name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.tb = SummaryWriter(self.log_dir)
        self.log_image_interval = trainer.log_image_interval

    def add_image(self, x, y, samples, name, step):
        """Write the first input and target of the batch, plus a grid of samples, to TensorBoard."""
        self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
        self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
        if samples is None:  # nothing to show, e.g. the trainer skipped sample logging for this step
            return
        # One grid tile per element of `samples` (its first entry along the next axis).
        sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
        self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

    def log_train(self, step, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        """Log training scalars; images only every `log_image_interval` steps."""
        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="train/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="train/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="train/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
        if step % self.log_image_interval == 0:
            self.add_image(x, y, samples, "train", step)

    def log_validation(self, step, metric, loss, x, y, samples, mask_loss, iou_regression_loss, model_iou):
        """Log validation scalars and images (images are logged on every call)."""
        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
        self.tb.add_scalar(tag="validation/mask_loss", scalar_value=mask_loss, global_step=step)
        self.tb.add_scalar(tag="validation/iou_loss", scalar_value=iou_regression_loss, global_step=step)
        self.tb.add_scalar(tag="validation/model_iou", scalar_value=model_iou, global_step=step)
        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
        self.add_image(x, y, samples, "validation", step)
class SamTrainer(torch_em.trainer.default_trainer.DefaultTrainer):
 19class SamTrainer(torch_em.trainer.DefaultTrainer):
 20    """Trainer class for training the Segment Anything model.
 21
 22    This class is derived from `torch_em.trainer.DefaultTrainer`.
 23    Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py
 24    for details on its usage and implementation.
 25
 26    Args:
 27        convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM.
 28            The class `micro_sam.training.util.ConvertToSamInputs` can be used here.
 29        n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated.
 30            In each sub-iteration new point prompts are sampled where the model was wrong.
 31        n_objects_per_batch: If not given, we compute the loss for all objects in a sample.
 32            Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
 33        mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU.
 34            By default, set to the expected mse loss function.
 35        prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training.
 36            Already allocated with the desired prompt generator by default.
 37        mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`).
 38            By default, set to '0.5'.
 39        mask_loss: The loss to compare the predicted masks and the targets. By default, set to the dice loss function.
 40        kwargs: The keyword arguments of the `DefaultTrainer` super class.
 41    """
 42
 43    def __init__(
 44        self,
 45        convert_inputs: Callable,
 46        n_sub_iteration: int,
 47        n_objects_per_batch: Optional[int] = None,
 48        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
 49        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
 50        mask_prob: float = 0.5,
 51        mask_loss: Optional[torch.nn.Module] = None,
 52        **kwargs
 53    ):
 54        if mask_loss is None:
 55            # We have to use the Dice Loss with reduce channel set to None.
 56            # Hence we hard-code it here to avoid issues by passsing wrong options for the loss.
 57            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
 58        else:
 59            self.mask_loss = mask_loss
 60
 61        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
 62        self.convert_inputs = convert_inputs
 63        self.mse_loss = mse_loss
 64        self.n_objects_per_batch = n_objects_per_batch
 65        self.n_sub_iteration = n_sub_iteration
 66        self.prompt_generator = prompt_generator
 67        self.mask_prob = mask_prob
 68        self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized()
 69        self._kwargs = kwargs
 70
 71    def _get_prompt_and_multimasking_choices(self, current_iteration):
 72        """Choose the type of prompts we sample for training, and then we call
 73        'convert_inputs' with the correct prompting from here.
 74        """
 75        if current_iteration % 2 == 0:  # sample only a single point per object
 76            n_pos, n_neg = 1, 0
 77            get_boxes = False
 78            multimask_output = True
 79
 80        else:  # sample only a single box per object
 81            n_pos, n_neg = 0, 0
 82            get_boxes = True
 83            multimask_output = False
 84
 85        return n_pos, n_neg, get_boxes, multimask_output
 86
 87    def _get_prompt_and_multimasking_choices_for_val(self, current_iteration):
 88        """Choose the type of prompts we sample for validation, and then we call
 89        'convert_inputs' with the correct prompting from here.
 90        """
 91        if current_iteration % 4 == 0:  # sample only a single point per object
 92            n_pos, n_neg = 1, 0
 93            get_boxes = False
 94            multimask_output = True
 95
 96        elif current_iteration % 4 == 1:  # sample only a single box per object
 97            n_pos, n_neg = 0, 0
 98            get_boxes = True
 99            multimask_output = False
100
101        elif current_iteration % 4 == 2:  # sample a random no. of points
102            pos_range, neg_range = 4, 4
103
104            n_pos = np.random.randint(1, pos_range + 1)
105            if n_pos == 1:  # to avoid (1, 0) combination for redundancy but still have (n_pos, 0)
106                n_neg = np.random.randint(1, neg_range + 1)
107            else:
108                n_neg = np.random.randint(0, neg_range + 1)
109            get_boxes = False
110            multimask_output = False
111
112        else:  # sample boxes AND random no. of points
113            # here we can have (1, 0) because we also have box
114            pos_range, neg_range = 4, 4
115
116            n_pos = np.random.randint(1, pos_range + 1)
117            n_neg = np.random.randint(0, neg_range + 1)
118            get_boxes = True
119            multimask_output = False
120
121        return n_pos, n_neg, get_boxes, multimask_output
122
123    def _compute_iou(self, pred, true, eps=1e-7):
124        """Compute the IoU score between the prediction and target.
125        """
126        pred_mask = pred > 0.5  # binarizing the output predictions
127        overlap = pred_mask.logical_and(true).sum(dim=(1, 2, 3))
128        union = pred_mask.logical_or(true).sum(dim=(1, 2, 3))
129        iou = overlap / (union + eps)
130        return iou
131
132    def _compute_loss(self, batched_outputs, y_one_hot):
133        """Compute the loss for one iteration. The loss is made up of two components:
134        - The mask loss: dice score between the predicted masks and targets.
135        - The IOU loss: L2 loss between the predicted IOU and the actual IOU of prediction and target.
136        """
137        mask_loss, iou_regression_loss = 0.0, 0.0
138
139        # Loop over the batch.
140        for batch_output, targets in zip(batched_outputs, y_one_hot):
141
142            predicted_objects = torch.sigmoid(batch_output["masks"])
143            # Compute the dice scores for the 1 or 3 predicted masks per true object (outer loop).
144            # We swap the axes that go into the dice loss so that the object axis
145            # corresponds to the channel axes. This ensures that the dice is computed
146            # independetly per channel. We do not reduce the channel axis in the dice,
147            # so that we can take the minimum (best score) of the 1/3 predicted masks per object.
148            dice_scores = torch.stack([
149                self.loss(predicted_objects[:, i:i+1].swapaxes(0, 1), targets.swapaxes(0, 1))
150                for i in range(predicted_objects.shape[1])
151            ])
152            dice_scores, _ = torch.min(dice_scores, dim=0)
153
154            # Compute the actual IOU between the predicted and true objects.
155            # The outer loop is for the 1 or 3 predicted masks per true object.
156            with torch.no_grad():
157                true_iou = torch.stack([
158                    self._compute_iou(predicted_objects[:, i:i+1], targets) for i in range(predicted_objects.shape[1])
159                ])
160            # Compute the L2 loss between true and predicted IOU. We need to swap the axes so that
161            # the object axis is back in the first dimension.
162            iou_score = self.mse_loss(true_iou.swapaxes(0, 1), batch_output["iou_predictions"])
163
164            mask_loss = mask_loss + torch.mean(dice_scores)
165            iou_regression_loss = iou_regression_loss + iou_score
166
167        loss = mask_loss + iou_regression_loss
168
169        return loss, mask_loss, iou_regression_loss
170
171    #
172    # Functionality for iterative prompting loss
173    #
174
175    def _get_best_masks(self, batched_outputs, batched_iou_predictions):
176        # Batched mask and logit (low-res mask) predictions.
177        masks = torch.stack([m["masks"] for m in batched_outputs])
178        logits = torch.stack([m["low_res_masks"] for m in batched_outputs])
179
180        # Determine the best IOU across the multi-object prediction axis
181        # and turn this into a mask we can use for indexing.
182        # See https://stackoverflow.com/questions/72628000/pytorch-indexing-by-argmax
183        # for details on the indexing logic.
184        best_iou_idx = torch.argmax(batched_iou_predictions, dim=2, keepdim=True)
185        best_iou_idx = torch.zeros_like(batched_iou_predictions).scatter(2, best_iou_idx, value=1).bool()
186
187        # Index the mask and logits with the best iou indices.
188        # Note that we squash the first two axes (batch x objects) into one when indexing.
189        # That's why we need to reshape bax into (batch x objects) using a view.
190        # We also keep the multi object axis as a singleton, that's why the view has (batch_size, n_objects, 1, ...)
191        batch_size, n_objects = masks.shape[:2]
192        h, w = masks.shape[-2:]
193        masks = masks[best_iou_idx].view(batch_size, n_objects, 1, h, w)
194
195        h, w = logits.shape[-2:]
196        logits = logits[best_iou_idx].view(batch_size, n_objects, 1, h, w)
197
198        # Binarize the mask. Note that the mask here also contains logits, so we use 0.0
199        # as threshold instead of using 0.5. (Hence we don't need to apply a sigmoid)
200        masks = (masks > 0.0).float()
201        return masks, logits
202
203    def _use_mask_inputs(self, batched_inputs, y_one_hot):
204        # Whether to use masks per training top-iteration.
205        use_mask_inputs = False  # determines if each sub-iteration will use mask inputs as prompts or not.
206        use_zero_mask = False  # determines if the zeroth iteration will use zeros as mask inputs.
207
208        if self.mask_prob == 1:  # i.e. always use masks.
209            use_mask_inputs = True  # we would like to use mask inputs in all sub-iterations.
210            use_zero_mask = self.is_data_parallel  # we would like to use zeros as mask inputs for zeroth iteration.
211
212        elif self.mask_prob > 0:  # i.e. if we use mask inputs with a probability.
213            if self.is_data_parallel:  # if training on multiple GPUs.
214                if torch.distributed.get_rank() == 0:  # device with rank 0.
215                    use_mask_inputs_tensor = torch.tensor(
216                        random.random() < self.mask_prob, dtype=torch.uint8, device=self.device,
217                    )
218                else:  # on other devices, we do not need this parameter at this stage.
219                    use_mask_inputs_tensor = torch.tensor(0, dtype=torch.uint8, device=self.device)
220
221                # Broadcast the value to all devices (ranks).
222                torch.distributed.broadcast(use_mask_inputs_tensor, src=0)
223
224                # And convert it back to our desired boolean value.
225                use_mask_inputs = bool(use_mask_inputs_tensor.item())
226                use_zero_mask = use_mask_inputs  # provides zeros as mask inputs.
227            else:  # training on a single GPU.
228                use_mask_inputs = None
229
230        if use_zero_mask:
231            # We use zeros as mask inputs for the zeroth iteration.
232            y_zeros = torch.zeros((*y_one_hot.shape[:3], 256, 256))
233
234            # Add zeros as mask inputs to batched inputs.
235            for bi, curr_masks in zip(batched_inputs, y_zeros):
236                bi["mask_inputs"] = curr_masks
237
238        return batched_inputs, use_mask_inputs
239
240    def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multimask_output):
241        """Compute the loss for several (sub-)iterations of iterative prompting.
242        In each iterations the prompts are updated based on the previous predictions.
243        """
244        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
245
246        loss, mask_loss, iou_regression_loss, mean_model_iou = 0.0, 0.0, 0.0, 0.0
247
248        # Whether to use mask inputs in each sub-iteration.
249        batched_inputs, use_mask_inputs = self._use_mask_inputs(batched_inputs, y_one_hot)
250
251        for i in range(0, num_subiter):
252            # We do multimasking only in the first sub-iteration as we then pass single prompt
253            # after the first sub-iteration, we don't do multimasking because we get multiple prompts.
254            batched_outputs = self.model(
255                batched_inputs=batched_inputs,
256                image_embeddings=image_embeddings,
257                multimask_output=multimask_output if i == 0 else False,
258            )
259
260            # Compute loss for this sub-iteration.
261            net_loss, net_mask_loss, net_iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
262
263            # Compute the mean IOU predicted by the model. We keep track of this in the logger.
264            batched_iou_predictions = torch.stack([m["iou_predictions"] for m in batched_outputs])
265            with torch.no_grad():
266                net_mean_model_iou = torch.mean(batched_iou_predictions)
267
268            loss = loss + net_loss
269            mask_loss = mask_loss + net_mask_loss
270            iou_regression_loss = iou_regression_loss + net_iou_regression_loss
271            mean_model_iou = mean_model_iou + net_mean_model_iou
272
273            if i < (num_subiter - 1):   # We need not update the prompts for the last iteration.
274                # Determine the next prompts based on current predictions.
275                with torch.no_grad():
276                    # Get the mask and logit predictions corresponding to the predicted object
277                    # (per actual object) with the best IOU.
278                    masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions)
279                    batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits, use_mask_inputs)
280
281        loss = loss / num_subiter
282        mask_loss = mask_loss / num_subiter
283        iou_regression_loss = iou_regression_loss / num_subiter
284        mean_model_iou = mean_model_iou / num_subiter
285
286        return loss, mask_loss, iou_regression_loss, mean_model_iou
287
288    def _update_prompts(self, batched_inputs, y_one_hot, masks, logits_masks, use_mask_inputs):
289        # here, we get the pair-per-batch of predicted and true elements (and also the "batched_inputs")
290        for x1, x2, _inp, logits in zip(masks, y_one_hot, batched_inputs, logits_masks):
291            # here, we get each object in the pairs and do the point choices per-object
292            net_coords, net_labels, _, _ = self.prompt_generator(x2, x1)
293
294            # convert the point coordinates to the expected resolution for iterative prompting
295            # NOTE:
296            #   - "only" need to transform the point prompts from the iterative prompting
297            #   - the `logits` are the low res masks (256, 256), hence do not need the transform
298            net_coords = self.model.transform.apply_coords_torch(net_coords, y_one_hot.shape[-2:])
299
300            updated_point_coords = torch.cat([_inp["point_coords"], net_coords], dim=1) \
301                if "point_coords" in _inp.keys() else net_coords
302            updated_point_labels = torch.cat([_inp["point_labels"], net_labels], dim=1) \
303                if "point_labels" in _inp.keys() else net_labels
304
305            _inp["point_coords"] = updated_point_coords
306            _inp["point_labels"] = updated_point_labels
307
308            if self.is_data_parallel:  # multi-GPU training
309                use_mask_inputs_this_iter = use_mask_inputs
310            else:  # single GPU training
311                if self.mask_prob > 0:
312                    # using mask inputs for iterative prompting while training, with a probability
313                    use_mask_inputs = (random.random() < self.mask_prob)
314                else:  # otherwise we assume it is 0 and do not need the generator to decide.
315                    use_mask_inputs = False
316
317                use_mask_inputs_this_iter = use_mask_inputs
318
319            if use_mask_inputs_this_iter:
320                _inp["mask_inputs"] = logits
321            else:  # remove previously existing mask inputs to avoid using them in next sub-iteration.
322                _inp.pop("mask_inputs", None)
323
324        return batched_inputs
325
326    #
327    # Training Loop
328    #
329
330    def _preprocess_batch(self, batched_inputs, y, sampled_ids):
331        """Compute one hot target (one mask per channel) for the sampled ids
332        and restrict the number of sampled objects to the minimal number in the batch.
333        """
334        assert len(y) == len(sampled_ids)
335
336        # Get the minimal number of objects in this batch.
337        # The number of objects in a patch might be < n_objects_per_batch.
338        # This is why we need to restrict it here to ensure the same
339        # number of objects across the batch.
340        n_objects = min(len(ids) for ids in sampled_ids)
341
342        y = y.to(self.device, non_blocking=True)
343        # Compute the one hot targets for the seg-id.
344        y_one_hot = torch.stack([
345            torch.stack([target == seg_id for seg_id in ids[:n_objects]])
346            for target, ids in zip(y, sampled_ids)
347        ]).float()
348
349        # Also restrict the prompts to the number of objects.
350        batched_inputs = [
351            {k: (v[:n_objects] if k in ("point_coords", "point_labels", "boxes") else v) for k, v in inp.items()}
352            for inp in batched_inputs
353        ]
354        return batched_inputs, y_one_hot
355
356    def _interactive_train_iteration(self, x, y):
357        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
358
359        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
360        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
361
362        loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
363            batched_inputs=batched_inputs,
364            y_one_hot=y_one_hot,
365            num_subiter=self.n_sub_iteration,
366            multimask_output=multimask_output
367        )
368        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot
369
370    def _check_input_normalization(self, x, input_check_done):
371        # The expected data range of the SAM model is 8bit (0-255).
372        # It can easily happen that data is normalized beforehand in training.
373        # For some reasons we don't fully understand this still works, but it
374        # should still be avoided and is very detrimental in some settings
375        # (e.g. when freezing the image encoder)
376        # We check once per epoch if the data seems to be normalized already and
377        # raise a warning if this is the case.
378        if not input_check_done:
379            data_min, data_max = x.min(), x.max()
380            if (data_min < 0) or (data_max < 1):
381                warnings.warn(
382                    "It looks like you are normalizing the training data. "
383                    "The SAM model takes care of normalization, so it is better to not do this. "
384                    "We recommend to remove data normalization and input data in the range [0, 255]."
385                )
386            input_check_done = True
387
388        return input_check_done
389
390    def _train_epoch_impl(self, progress, forward_context, backprop):
391        self.model.train()
392
393        input_check_done = False
394
395        n_iter = 0
396        t_per_iter = time.time()
397        for x, y in self.train_loader:
398            input_check_done = self._check_input_normalization(x, input_check_done)
399
400            self.optimizer.zero_grad()
401
402            with forward_context():
403                (loss, mask_loss, iou_regression_loss, model_iou,
404                 sampled_binary_y) = self._interactive_train_iteration(x, y)
405
406            backprop(loss)
407
408            if self.logger is not None:
409                lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
410                samples = sampled_binary_y if self._iteration % self.log_image_interval == 0 else None
411                self.logger.log_train(
412                    self._iteration, loss, lr, x, y, samples, mask_loss, iou_regression_loss, model_iou
413                )
414
415            self._iteration += 1
416            n_iter += 1
417            if self._iteration >= self.max_iteration:
418                break
419            progress.update(1)
420
421        t_per_iter = (time.time() - t_per_iter) / n_iter
422        return t_per_iter
423
424    def _interactive_val_iteration(self, x, y, val_iteration):
425        n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices_for_val(val_iteration)
426
427        batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
428        batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
429
430        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
431
432        batched_outputs = self.model(
433            batched_inputs=batched_inputs,
434            image_embeddings=image_embeddings,
435            multimask_output=multimask_output,
436        )
437
438        loss, mask_loss, iou_regression_loss = self._compute_loss(batched_outputs, y_one_hot)
439        # We use the dice loss over the masks as metric.
440        metric = mask_loss
441        model_iou = torch.mean(torch.stack([m["iou_predictions"] for m in batched_outputs]))
442
443        return loss, mask_loss, iou_regression_loss, model_iou, y_one_hot, metric
444
445    def _validate_impl(self, forward_context):
446        self.model.eval()
447
448        input_check_done = False
449
450        val_iteration = 0
451        metric_val, loss_val, model_iou_val = 0.0, 0.0, 0.0
452
453        with torch.no_grad():
454            for x, y in self.val_loader:
455                input_check_done = self._check_input_normalization(x, input_check_done)
456
457                with forward_context():
458                    (loss, mask_loss, iou_regression_loss, model_iou,
459                     sampled_binary_y, metric) = self._interactive_val_iteration(x, y, val_iteration)
460
461                loss_val += loss.item()
462                metric_val += metric.item()
463                model_iou_val += model_iou.item()
464                val_iteration += 1
465
466        loss_val /= len(self.val_loader)
467        metric_val /= len(self.val_loader)
468        model_iou_val /= len(self.val_loader)
469        print()
470        print(f"The Average Dice Score for the Current Epoch is {1 - metric_val}")
471
472        if self.logger is not None:
473            self.logger.log_validation(
474                self._iteration, metric_val, loss_val, x, y,
475                sampled_binary_y, mask_loss, iou_regression_loss, model_iou_val
476            )
477
478        return metric_val

Trainer class for training the Segment Anything model.

This class is derived from torch_em.trainer.DefaultTrainer. Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py for details on its usage and implementation.

Arguments:
  • convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. The class micro_sam.training.util.ConvertToSamInputs can be used here.
  • n_sub_iteration: The number of iteration steps for which the masks predicted for one object are updated. In each sub-iteration new point prompts are sampled where the model was wrong.
  • n_objects_per_batch: If not given, we compute the loss for all objects in a sample. Otherwise the loss computation is limited to n_objects_per_batch, and the objects are randomly sampled.
  • mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. Defaults to torch.nn.MSELoss().
  • prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training. Defaults to an IterativePromptGenerator instance.
  • mask_prob: The probability of using the mask inputs in the iterative prompting (per n_sub_iteration). Defaults to 0.5.
  • mask_loss: The loss to compare the predicted masks and the targets. Defaults to a Dice loss (torch_em.loss.DiceLoss with reduce_channel=None).
  • kwargs: The keyword arguments of the DefaultTrainer super class.
SamTrainer( convert_inputs: Callable, n_sub_iteration: int, n_objects_per_batch: Optional[int] = None, mse_loss: torch.nn.modules.module.Module = MSELoss(), prompt_generator: micro_sam.prompt_generators.PromptGeneratorBase = <micro_sam.prompt_generators.IterativePromptGenerator object>, mask_prob: float = 0.5, mask_loss: Optional[torch.nn.modules.module.Module] = None, **kwargs)
43    def __init__(
44        self,
45        convert_inputs: Callable,
46        n_sub_iteration: int,
47        n_objects_per_batch: Optional[int] = None,
48        mse_loss: torch.nn.Module = torch.nn.MSELoss(),
49        prompt_generator: PromptGeneratorBase = IterativePromptGenerator(),
50        mask_prob: float = 0.5,
51        mask_loss: Optional[torch.nn.Module] = None,
52        **kwargs
53    ):
54        if mask_loss is None:
55            # We have to use the Dice Loss with reduce channel set to None.
56            # Hence we hard-code it here to avoid issues by passsing wrong options for the loss.
57            self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None)
58        else:
59            self.mask_loss = mask_loss
60
61        super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs)
62        self.convert_inputs = convert_inputs
63        self.mse_loss = mse_loss
64        self.n_objects_per_batch = n_objects_per_batch
65        self.n_sub_iteration = n_sub_iteration
66        self.prompt_generator = prompt_generator
67        self.mask_prob = mask_prob
68        self.is_data_parallel = torch.distributed.is_available() and torch.distributed.is_initialized()
69        self._kwargs = kwargs
convert_inputs
mse_loss
n_objects_per_batch
n_sub_iteration
prompt_generator
mask_prob
is_data_parallel
Inherited Members
torch_em.trainer.default_trainer.DefaultTrainer
name
id_
train_loader
val_loader
model
loss
optimizer
metric
device
lr_scheduler
log_image_interval
save_root
compile_model
rank
mixed_precision
early_stopping
train_time
logger_class
logger_kwargs
checkpoint_folder
iteration
epoch
Deserializer
Serializer
fit