micro_sam.training.trainable_sam

  1from typing import Any, Dict, List, Tuple
  2
  3import torch
  4from torch import nn
  5from torch.nn import functional as F
  6
  7from segment_anything.modeling import Sam
  8from segment_anything.utils.transforms import ResizeLongestSide
  9
 10
 11# simple wrapper around SAM in order to keep things trainable
 12class TrainableSAM(nn.Module):
 13    """Wrapper to make the SegmentAnything model trainable.
 14
 15    Args:
 16        sam: The SegmentAnything Model.
 17    """
 18    def __init__(self, sam: Sam) -> None:
 19        super().__init__()
 20        self.sam = sam
 21        self.transform = ResizeLongestSide(sam.image_encoder.img_size)
 22
 23    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
 24        """Resize, normalize pixel values and pad to a square input.
 25
 26        Args:
 27            x: The input tensor.
 28
 29        Returns:
 30            The resized, normalized and padded tensor.
 31            The shape of the image after resizing.
 32        """
 33
 34        # Resize longest side to match the image encoder.
 35        x = self.transform.apply_image_torch(x)
 36        input_size = x.shape[-2:]
 37
 38        # Normalize colors
 39        x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0)
 40
 41        # Pad
 42        h, w = x.shape[-2:]
 43        padh = self.sam.image_encoder.img_size - h
 44        padw = self.sam.image_encoder.img_size - w
 45        x = F.pad(x, (0, padw, 0, padh))
 46        return x, input_size
 47
 48    def image_embeddings_oft(self, batched_inputs):
 49        # Compute the input images.
 50        input_images, input_size = self.preprocess(
 51            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.sam.device, non_blocking=True)
 52        )
 53        # Update the input size for each input in the batch.
 54        for i in range(len(batched_inputs)):
 55            batched_inputs[i]["input_size"] = input_size
 56        # Compute the image embeddings.
 57        image_embeddings = self.sam.image_encoder(input_images)
 58        return image_embeddings, batched_inputs
 59
 60    # batched inputs follow the same syntax as the input to sam.forward
 61    def forward(
 62        self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False,
 63    ) -> List[Dict[str, Any]]:
 64        """Forward pass.
 65
 66        Args:
 67            batched_inputs: The batched input images and prompts.
 68            image_embeddings: The precomputed image embeddings, e.g. computed via image_embeddings_oft.
 69            multimask_output: Whether to predict multiple or just a single mask.
 70
 71        Returns:
 72            The predicted segmentation masks and iou values.
 73        """
 74        outputs = []
 75        for image_record, curr_embedding in zip(batched_inputs, image_embeddings):
 76            if "point_coords" in image_record:
 77                points = (
 78                    image_record["point_coords"].to(self.sam.device, non_blocking=True),
 79                    image_record["point_labels"].to(self.sam.device, non_blocking=True)
 80                )
 81            else:
 82                points = None
 83
 84            if "boxes" in image_record:
 85                boxes = image_record.get("boxes").to(self.sam.device, non_blocking=True)
 86            else:
 87                boxes = None
 88
 89            if "mask_inputs" in image_record:
 90                masks = image_record.get("mask_inputs").to(self.sam.device, non_blocking=True)
 91            else:
 92                masks = None
 93
 94            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
 95                points=points,
 96                boxes=boxes,
 97                masks=masks,
 98            )
 99
100            low_res_masks, iou_predictions = self.sam.mask_decoder(
101                image_embeddings=curr_embedding.unsqueeze(0),
102                image_pe=self.sam.prompt_encoder.get_dense_pe(),
103                sparse_prompt_embeddings=sparse_embeddings,
104                dense_prompt_embeddings=dense_embeddings,
105                multimask_output=multimask_output,
106            )
107
108            masks = self.sam.postprocess_masks(
109                low_res_masks,
110                input_size=image_record["input_size"],
111                original_size=image_record["original_size"],
112            )
113
114            outputs.append(
115                {
116                    "low_res_masks": low_res_masks,
117                    "masks": masks,
118                    "iou_predictions": iou_predictions
119                }
120            )
121
122        return outputs
class TrainableSAM(torch.nn.modules.module.Module):

Wrapper to make the SegmentAnything model trainable.

Arguments:
  • sam: The SegmentAnything Model.
TrainableSAM(sam: segment_anything.modeling.sam.Sam)
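A minimal construction sketch (the model type and checkpoint path below are placeholders, not part of this module):

    from segment_anything import sam_model_registry
    from micro_sam.training.trainable_sam import TrainableSAM

    # Load a pretrained SAM model; "vit_b" and the checkpoint path are example values.
    sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")
    trainable_sam = TrainableSAM(sam)
    trainable_sam.train()  # image encoder, prompt encoder and mask decoder all remain trainable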

def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:

Resize, normalize pixel values and pad to a square input.

Arguments:
  • x: The input tensor.
Returns:

The resized, normalized and padded tensor. The shape of the image after resizing.
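For example, with the default image encoder input size of 1024, an image of shape 512x768 is resized so that its longest side becomes 1024 and then zero-padded to a square (a sketch, assuming the trainable_sam instance from the example above):

    import torch

    image = 255 * torch.rand(1, 3, 512, 768)  # batch of RGB images with values in [0, 255]
    x, input_size = trainable_sam.preprocess(image.to(trainable_sam.sam.device))
    print(x.shape)      # torch.Size([1, 3, 1024, 1024]): padded square input for the image encoder
    print(input_size)   # torch.Size([683, 1024]): shape after resizing, before padding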

def image_embeddings_oft(self, batched_inputs):
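Precomputes the image embeddings for a batch of samples and records the resized input size on each record, so the comparatively expensive image encoder pass can be reused, e.g. across prompt iterations during training. A sketch, assuming the trainable_sam instance from above and images of identical shape:

    import torch

    batched_inputs = [
        {"image": 255 * torch.rand(3, 512, 768), "original_size": (512, 768)},
        {"image": 255 * torch.rand(3, 512, 768), "original_size": (512, 768)},
    ]
    image_embeddings, batched_inputs = trainable_sam.image_embeddings_oft(batched_inputs)
    print(image_embeddings.shape)  # e.g. torch.Size([2, 256, 64, 64]) for a ViT-B image encoder
    # every record now also carries "input_size", which forward needs for mask postprocessing
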
def forward(self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False) -> List[Dict[str, Any]]:

Forward pass.

Arguments:
  • batched_inputs: The batched input images and prompts.
  • image_embeddings: The precomputed image embeddings, e.g. computed via image_embeddings_oft.
  • multimask_output: Whether to predict multiple or just a single mask.
Returns:

The predicted segmentation masks and iou values.
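Putting it together, a single training step could look roughly as follows (a sketch only; the point prompt coordinates, the target mask and the loss are placeholders and not part of this module). Since forward passes the prompts straight to the prompt encoder, point coordinates are assumed to already be in the resized 1024x1024 input frame rather than the original image frame:

    import torch
    from torch.nn import functional as F

    # Attach one positive point prompt per sample (placeholder coordinates, (x, y) order).
    for record in batched_inputs:
        record["point_coords"] = torch.tensor([[[600.0, 350.0]]])  # shape (num_objects, num_points, 2)
        record["point_labels"] = torch.tensor([[1]])                # shape (num_objects, num_points)

    outputs = trainable_sam(batched_inputs, image_embeddings, multimask_output=False)
    predicted = outputs[0]["masks"]             # mask logits of shape (num_objects, 1, 512, 768)
    target = torch.zeros_like(predicted)        # placeholder ground-truth mask
    loss = F.binary_cross_entropy_with_logits(predicted, target)
    loss.backward()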
