micro_sam.training.trainable_sam
```python
from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from segment_anything.modeling import Sam
from segment_anything.utils.transforms import ResizeLongestSide


# simple wrapper around SAM in order to keep things trainable
class TrainableSAM(nn.Module):
    """Wrapper to make the SegmentAnything model trainable.

    Args:
        sam: The SegmentAnything Model.
    """

    def __init__(self, sam: Sam) -> None:
        super().__init__()
        self.sam = sam
        self.transform = ResizeLongestSide(sam.image_encoder.img_size)

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Resize, normalize pixel values and pad to a square input.

        Args:
            x: The input tensor.

        Returns:
            The resized, normalized and padded tensor.
            The shape of the image after resizing.
        """

        # Resize longest side to match the image encoder.
        x = self.transform.apply_image_torch(x)
        input_size = x.shape[-2:]

        # Normalize colors
        x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0)

        # Pad
        h, w = x.shape[-2:]
        padh = self.sam.image_encoder.img_size - h
        padw = self.sam.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_size

    def image_embeddings_oft(self, batched_inputs):
        # Compute the input images.
        input_images, input_size = self.preprocess(
            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.sam.device, non_blocking=True)
        )
        # Update the input size for each input in the batch.
        for i in range(len(batched_inputs)):
            batched_inputs[i]["input_size"] = input_size
        # Compute the image embeddings.
        image_embeddings = self.sam.image_encoder(input_images)

        return image_embeddings, batched_inputs

    # batched inputs follow the same syntax as the input to sam.forward
    def forward(
        self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False,
    ) -> List[Dict[str, Any]]:
        """Forward pass.

        Args:
            batched_inputs: The batched input images and prompts.
            image_embeddings: The precomputed image embeddings. If not passed then they will be computed.
            multimask_output: Whether to predict multiple or just a single mask.

        Returns:
            The predicted segmentation masks and iou values.
        """
        outputs = []
        for image_record, curr_embedding in zip(batched_inputs, image_embeddings):
            if "point_coords" in image_record:
                points = (
                    image_record["point_coords"].to(self.sam.device, non_blocking=True),
                    image_record["point_labels"].to(self.sam.device, non_blocking=True)
                )
            else:
                points = None

            if "boxes" in image_record:
                boxes = image_record.get("boxes").to(self.sam.device, non_blocking=True)
            else:
                boxes = None

            if "mask_inputs" in image_record:
                masks = image_record.get("mask_inputs").to(self.sam.device, non_blocking=True)
            else:
                masks = None

            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(points=points, boxes=boxes, masks=masks)

            low_res_masks, iou_predictions = self.sam.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )

            masks = self.sam.postprocess_masks(
                masks=low_res_masks, input_size=image_record["input_size"], original_size=image_record["original_size"],
            )

            outputs.append(
                {"low_res_masks": low_res_masks, "masks": masks, "iou_predictions": iou_predictions}
            )

        return outputs
```
class TrainableSAM(torch.nn.modules.module.Module):
Wrapper to make the SegmentAnything model trainable.
Arguments:
- sam: The SegmentAnything Model.
TrainableSAM(sam: segment_anything.modeling.sam.Sam)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
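A minimal usage sketch, assuming a locally available Segment Anything checkpoint (the model type "vit_b" and the checkpoint path below are placeholders):

```python
import torch
from segment_anything import sam_model_registry

from micro_sam.training.trainable_sam import TrainableSAM

# Build a vanilla SAM model; "vit_b" and the checkpoint path are placeholders.
sam = sam_model_registry["vit_b"](checkpoint="/path/to/sam_vit_b.pth")

# Wrap it so that image encoder, prompt encoder and mask decoder remain trainable.
trainable_sam = TrainableSAM(sam)
trainable_sam.train()

# The wrapper is a regular nn.Module, so its parameters can be handed to any optimizer.
optimizer = torch.optim.Adam(trainable_sam.parameters(), lr=1e-5)
```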
def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
Resize, normalize pixel values and pad to a square input.
Arguments:
- x: The input tensor.
Returns:
- The resized, normalized and padded tensor.
- The shape of the image after resizing.
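A small sketch of what this does in practice, reusing the trainable_sam instance from above (the input shape is made up):

```python
import torch

# A hypothetical float batch of two RGB images in (B, C, H, W) layout with values in [0, 255].
x = torch.rand(2, 3, 512, 768) * 255

batch, input_size = trainable_sam.preprocess(x)

# The longest side is resized to the encoder resolution (1024 for the released SAM models)
# and the result is zero-padded to a square, so batch has shape (2, 3, 1024, 1024).
print(batch.shape)

# input_size is the shape after resizing but before padding, here (683, 1024);
# it is later used to crop the padding away again when postprocessing masks.
print(input_size)
```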
def image_embeddings_oft(self, batched_inputs):
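This method preprocesses and stacks the batched images, writes the resulting "input_size" into every input record and runs the image encoder once, returning the embeddings together with the updated records. A hedged sketch of the expected input layout (the keys follow the conventions of forward; the data itself is made up):

```python
import torch

# Two made-up training samples: "image" is a CxHxW tensor with values in [0, 255],
# "original_size" is required later by forward for postprocessing the masks.
batched_inputs = [
    {"image": torch.rand(3, 512, 512) * 255, "original_size": (512, 512)},
    {"image": torch.rand(3, 512, 512) * 255, "original_size": (512, 512)},
]

# Runs preprocess + image encoder once for the whole batch; every record now also
# carries an "input_size" entry, so the embeddings can be reused across prompt iterations.
image_embeddings, batched_inputs = trainable_sam.image_embeddings_oft(batched_inputs)

# One embedding per image, e.g. (2, 256, 64, 64) for the released SAM models.
print(image_embeddings.shape)
```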
def forward(self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False) -> List[Dict[str, Any]]:
Forward pass.
Arguments:
- batched_inputs: The batched input images and prompts.
- image_embeddings: The precomputed image embeddings. If not passed then they will be computed.
- multimask_output: Whether to predict multiple or just a single mask.
Returns:
The predicted segmentation masks and iou values.
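Continuing the sketch from above with hypothetical box prompts (the coordinates are made up and assumed to already be given in the resized input space expected by the prompt encoder):

```python
import torch

# Attach one box prompt per image; shape (num_objects, 4) in XYXY format.
for record in batched_inputs:
    record["boxes"] = torch.tensor([[200.0, 200.0, 800.0, 800.0]])

outputs = trainable_sam(batched_inputs, image_embeddings, multimask_output=False)

# One dictionary per image: low resolution logits, masks resized back to the
# "original_size" of the input, and the decoder's IoU predictions.
low_res_masks = outputs[0]["low_res_masks"]      # e.g. (1, 1, 256, 256)
masks = outputs[0]["masks"]                      # e.g. (1, 1, 512, 512)
iou_predictions = outputs[0]["iou_predictions"]  # e.g. (1, 1)
```

Since the wrapper does not freeze any part of the model, a loss computed on masks or iou_predictions backpropagates through the mask decoder, the prompt encoder and the image encoder.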