micro_sam.training.trainable_sam

  1from typing import Any, Dict, List, Tuple
  2
  3import torch
  4from torch import nn
  5from torch.nn import functional as F
  6
  7from segment_anything.modeling import Sam
  8from segment_anything.utils.transforms import ResizeLongestSide
  9
 10
 11# simple wrapper around SAM in order to keep things trainable
 12class TrainableSAM(nn.Module):
 13    """Wrapper to make the SegmentAnything model trainable.
 14
 15    Args:
 16        sam: The SegmentAnything Model.
 17    """
 18
 19    def __init__(self, sam: Sam) -> None:
 20        super().__init__()
 21        self.sam = sam
 22        self.transform = ResizeLongestSide(sam.image_encoder.img_size)
 23
 24    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
 25        """Resize, normalize pixel values and pad to a square input.
 26
 27        Args:
 28            x: The input tensor.
 29
 30        Returns:
 31            The resized, normalized and padded tensor.
 32            The shape of the image after resizing.
 33        """
 34
 35        # Resize longest side to match the image encoder.
 36        x = self.transform.apply_image_torch(x)
 37        input_size = x.shape[-2:]
 38
 39        # Normalize colors
 40        x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0)
 41
 42        # Pad
 43        h, w = x.shape[-2:]
 44        padh = self.sam.image_encoder.img_size - h
 45        padw = self.sam.image_encoder.img_size - w
 46        x = F.pad(x, (0, padw, 0, padh))
 47        return x, input_size
 48
 49    def image_embeddings_oft(self, batched_inputs):
 50        # Compute the input images.
 51        input_images, input_size = self.preprocess(
 52            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.sam.device, non_blocking=True)
 53        )
 54        # Update the input size for each input in the batch.
 55        for i in range(len(batched_inputs)):
 56            batched_inputs[i]["input_size"] = input_size
 57        # Compute the image embeddings.
 58        image_embeddings = self.sam.image_encoder(input_images)
 59
 60        return image_embeddings, batched_inputs
 61
 62    # batched inputs follow the same syntax as the input to sam.forward
 63    def forward(
 64        self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False,
 65    ) -> List[Dict[str, Any]]:
 66        """Forward pass.
 67
 68        Args:
 69            batched_inputs: The batched input images and prompts.
 70            image_embeddings: The precompute image embeddings. If not passed then they will be computed.
 71            multimask_output: Whether to predict mutiple or just a single mask.
 72
 73        Returns:
 74            The predicted segmentation masks and iou values.
 75        """
 76        outputs = []
 77        for image_record, curr_embedding in zip(batched_inputs, image_embeddings):
 78            if "point_coords" in image_record:
 79                points = (
 80                    image_record["point_coords"].to(self.sam.device, non_blocking=True),
 81                    image_record["point_labels"].to(self.sam.device, non_blocking=True)
 82                )
 83            else:
 84                points = None
 85
 86            if "boxes" in image_record:
 87                boxes = image_record.get("boxes").to(self.sam.device, non_blocking=True)
 88            else:
 89                boxes = None
 90
 91            if "mask_inputs" in image_record:
 92                masks = image_record.get("mask_inputs").to(self.sam.device, non_blocking=True)
 93            else:
 94                masks = None
 95
 96            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(points=points, boxes=boxes, masks=masks)
 97
 98            low_res_masks, iou_predictions = self.sam.mask_decoder(
 99                image_embeddings=curr_embedding.unsqueeze(0),
100                image_pe=self.sam.prompt_encoder.get_dense_pe(),
101                sparse_prompt_embeddings=sparse_embeddings,
102                dense_prompt_embeddings=dense_embeddings,
103                multimask_output=multimask_output,
104            )
105
106            masks = self.sam.postprocess_masks(
107                masks=low_res_masks, input_size=image_record["input_size"], original_size=image_record["original_size"],
108            )
109
110            outputs.append(
111                {"low_res_masks": low_res_masks, "masks": masks, "iou_predictions": iou_predictions}
112            )
113
114        return outputs
class TrainableSAM(torch.nn.modules.module.Module):
 13class TrainableSAM(nn.Module):
 14    """Wrapper to make the SegmentAnything model trainable.
 15
 16    Args:
 17        sam: The SegmentAnything Model.
 18    """
 19
 20    def __init__(self, sam: Sam) -> None:
 21        super().__init__()
 22        self.sam = sam
 23        self.transform = ResizeLongestSide(sam.image_encoder.img_size)
 24
 25    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
 26        """Resize, normalize pixel values and pad to a square input.
 27
 28        Args:
 29            x: The input tensor.
 30
 31        Returns:
 32            The resized, normalized and padded tensor.
 33            The shape of the image after resizing.
 34        """
 35
 36        # Resize longest side to match the image encoder.
 37        x = self.transform.apply_image_torch(x)
 38        input_size = x.shape[-2:]
 39
 40        # Normalize colors
 41        x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0)
 42
 43        # Pad
 44        h, w = x.shape[-2:]
 45        padh = self.sam.image_encoder.img_size - h
 46        padw = self.sam.image_encoder.img_size - w
 47        x = F.pad(x, (0, padw, 0, padh))
 48        return x, input_size
 49
 50    def image_embeddings_oft(self, batched_inputs):
 51        # Compute the input images.
 52        input_images, input_size = self.preprocess(
 53            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.sam.device, non_blocking=True)
 54        )
 55        # Update the input size for each input in the batch.
 56        for i in range(len(batched_inputs)):
 57            batched_inputs[i]["input_size"] = input_size
 58        # Compute the image embeddings.
 59        image_embeddings = self.sam.image_encoder(input_images)
 60
 61        return image_embeddings, batched_inputs
 62
 63    # batched inputs follow the same syntax as the input to sam.forward
 64    def forward(
 65        self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False,
 66    ) -> List[Dict[str, Any]]:
 67        """Forward pass.
 68
 69        Args:
 70            batched_inputs: The batched input images and prompts.
 71            image_embeddings: The precompute image embeddings. If not passed then they will be computed.
 72            multimask_output: Whether to predict mutiple or just a single mask.
 73
 74        Returns:
 75            The predicted segmentation masks and iou values.
 76        """
 77        outputs = []
 78        for image_record, curr_embedding in zip(batched_inputs, image_embeddings):
 79            if "point_coords" in image_record:
 80                points = (
 81                    image_record["point_coords"].to(self.sam.device, non_blocking=True),
 82                    image_record["point_labels"].to(self.sam.device, non_blocking=True)
 83                )
 84            else:
 85                points = None
 86
 87            if "boxes" in image_record:
 88                boxes = image_record.get("boxes").to(self.sam.device, non_blocking=True)
 89            else:
 90                boxes = None
 91
 92            if "mask_inputs" in image_record:
 93                masks = image_record.get("mask_inputs").to(self.sam.device, non_blocking=True)
 94            else:
 95                masks = None
 96
 97            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(points=points, boxes=boxes, masks=masks)
 98
 99            low_res_masks, iou_predictions = self.sam.mask_decoder(
100                image_embeddings=curr_embedding.unsqueeze(0),
101                image_pe=self.sam.prompt_encoder.get_dense_pe(),
102                sparse_prompt_embeddings=sparse_embeddings,
103                dense_prompt_embeddings=dense_embeddings,
104                multimask_output=multimask_output,
105            )
106
107            masks = self.sam.postprocess_masks(
108                masks=low_res_masks, input_size=image_record["input_size"], original_size=image_record["original_size"],
109            )
110
111            outputs.append(
112                {"low_res_masks": low_res_masks, "masks": masks, "iou_predictions": iou_predictions}
113            )
114
115        return outputs

Wrapper to make the SegmentAnything model trainable.

Arguments:
  • sam: The SegmentAnything Model.
TrainableSAM(sam: segment_anything.modeling.sam.Sam)
20    def __init__(self, sam: Sam) -> None:
21        super().__init__()
22        self.sam = sam
23        self.transform = ResizeLongestSide(sam.image_encoder.img_size)

Initialize the internal nn.Module state, store the wrapped SAM model as `sam`, and create a `ResizeLongestSide` transform for the image encoder's input size as `transform`.

sam
transform
def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
25    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
26        """Resize, normalize pixel values and pad to a square input.
27
28        Args:
29            x: The input tensor.
30
31        Returns:
32            The resized, normalized and padded tensor.
33            The shape of the image after resizing.
34        """
35
36        # Resize longest side to match the image encoder.
37        x = self.transform.apply_image_torch(x)
38        input_size = x.shape[-2:]
39
40        # Normalize colors
41        x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0)
42
43        # Pad
44        h, w = x.shape[-2:]
45        padh = self.sam.image_encoder.img_size - h
46        padw = self.sam.image_encoder.img_size - w
47        x = F.pad(x, (0, padw, 0, padh))
48        return x, input_size

Resize, normalize pixel values and pad to a square input.

Arguments:
  • x: The input tensor.
Returns:

The resized, normalized and padded tensor, and the shape of the image after resizing (before padding).

def image_embeddings_oft(self, batched_inputs):
50    def image_embeddings_oft(self, batched_inputs):
51        # Compute the input images.
52        input_images, input_size = self.preprocess(
53            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.sam.device, non_blocking=True)
54        )
55        # Update the input size for each input in the batch.
56        for i in range(len(batched_inputs)):
57            batched_inputs[i]["input_size"] = input_size
58        # Compute the image embeddings.
59        image_embeddings = self.sam.image_encoder(input_images)
60
61        return image_embeddings, batched_inputs
def forward( self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False) -> List[Dict[str, Any]]:
 64    def forward(
 65        self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False,
 66    ) -> List[Dict[str, Any]]:
 67        """Forward pass.
 68
 69        Args:
 70            batched_inputs: The batched input images and prompts.
 71            image_embeddings: The precompute image embeddings. If not passed then they will be computed.
 72            multimask_output: Whether to predict mutiple or just a single mask.
 73
 74        Returns:
 75            The predicted segmentation masks and iou values.
 76        """
 77        outputs = []
 78        for image_record, curr_embedding in zip(batched_inputs, image_embeddings):
 79            if "point_coords" in image_record:
 80                points = (
 81                    image_record["point_coords"].to(self.sam.device, non_blocking=True),
 82                    image_record["point_labels"].to(self.sam.device, non_blocking=True)
 83                )
 84            else:
 85                points = None
 86
 87            if "boxes" in image_record:
 88                boxes = image_record.get("boxes").to(self.sam.device, non_blocking=True)
 89            else:
 90                boxes = None
 91
 92            if "mask_inputs" in image_record:
 93                masks = image_record.get("mask_inputs").to(self.sam.device, non_blocking=True)
 94            else:
 95                masks = None
 96
 97            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(points=points, boxes=boxes, masks=masks)
 98
 99            low_res_masks, iou_predictions = self.sam.mask_decoder(
100                image_embeddings=curr_embedding.unsqueeze(0),
101                image_pe=self.sam.prompt_encoder.get_dense_pe(),
102                sparse_prompt_embeddings=sparse_embeddings,
103                dense_prompt_embeddings=dense_embeddings,
104                multimask_output=multimask_output,
105            )
106
107            masks = self.sam.postprocess_masks(
108                masks=low_res_masks, input_size=image_record["input_size"], original_size=image_record["original_size"],
109            )
110
111            outputs.append(
112                {"low_res_masks": low_res_masks, "masks": masks, "iou_predictions": iou_predictions}
113            )
114
115        return outputs

Forward pass.

Arguments:
  • batched_inputs: The batched input images and prompts.
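  • image_embeddings: The precomputed image embeddings, one per entry of batched_inputs (e.g. as returned by image_embeddings_oft). This argument is required; the embeddings are not computed automatically.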
  • multimask_output: Whether to predict multiple masks or just a single mask.
Returns:

The predicted segmentation masks and iou values.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile