micro_sam.training.trainable_sam
from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from segment_anything.modeling import Sam
from segment_anything.utils.transforms import ResizeLongestSide


# simple wrapper around SAM in order to keep things trainable
class TrainableSAM(nn.Module):
    """Wrapper to make the SegmentAnything model trainable.

    Args:
        sam: The SegmentAnything Model.
    """
    def __init__(self, sam: Sam) -> None:
        super().__init__()
        self.sam = sam
        self.transform = ResizeLongestSide(sam.image_encoder.img_size)

    def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Resize, normalize pixel values and pad to a square input.

        Args:
            x: The input tensor.

        Returns:
            The resized, normalized and padded tensor.
            The shape of the image after resizing.
        """

        # Resize longest side to match the image encoder.
        x = self.transform.apply_image_torch(x)
        input_size = x.shape[-2:]

        # Normalize colors
        x = (x - self.sam.pixel_mean.unsqueeze(0)) / self.sam.pixel_std.unsqueeze(0)

        # Pad
        h, w = x.shape[-2:]
        padh = self.sam.image_encoder.img_size - h
        padw = self.sam.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x, input_size

    def image_embeddings_oft(self, batched_inputs):
        # Compute the input images.
        input_images, input_size = self.preprocess(
            torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.sam.device, non_blocking=True)
        )
        # Update the input size for each input in the batch.
        for i in range(len(batched_inputs)):
            batched_inputs[i]["input_size"] = input_size
        # Compute the image embeddings.
        image_embeddings = self.sam.image_encoder(input_images)
        return image_embeddings, batched_inputs

    # batched inputs follow the same syntax as the input to sam.forward
    def forward(
        self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False,
    ) -> List[Dict[str, Any]]:
        """Forward pass.

        Args:
            batched_inputs: The batched input images and prompts.
            image_embeddings: The precompute image embeddings. If not passed then they will be computed.
            multimask_output: Whether to predict mutiple or just a single mask.

        Returns:
            The predicted segmentation masks and iou values.
        """
        outputs = []
        for image_record, curr_embedding in zip(batched_inputs, image_embeddings):
            if "point_coords" in image_record:
                points = (
                    image_record["point_coords"].to(self.sam.device, non_blocking=True),
                    image_record["point_labels"].to(self.sam.device, non_blocking=True)
                )
            else:
                points = None

            if "boxes" in image_record:
                boxes = image_record.get("boxes").to(self.sam.device, non_blocking=True)
            else:
                boxes = None

            if "mask_inputs" in image_record:
                masks = image_record.get("mask_inputs").to(self.sam.device, non_blocking=True)
            else:
                masks = None

            sparse_embeddings, dense_embeddings = self.sam.prompt_encoder(
                points=points,
                boxes=boxes,
                masks=masks,
            )

            low_res_masks, iou_predictions = self.sam.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.sam.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )

            masks = self.sam.postprocess_masks(
                low_res_masks,
                input_size=image_record["input_size"],
                original_size=image_record["original_size"],
            )

            outputs.append(
                {
                    "low_res_masks": low_res_masks,
                    "masks": masks,
                    "iou_predictions": iou_predictions
                }
            )

        return outputs
class TrainableSAM(torch.nn.modules.module.Module):
Wrapper to make the SegmentAnything model trainable.
Arguments:
- sam: The SegmentAnything Model.
TrainableSAM(sam: segment_anything.modeling.sam.Sam)
Creates the trainable wrapper around the given SAM model and sets up a ResizeLongestSide transform matching the image encoder's input size.
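A minimal construction sketch, not taken from the module itself: it assumes segment_anything's sam_model_registry and a locally available "vit_b" checkpoint (checkpoint=None would also work, with randomly initialized weights).

import torch
from segment_anything import sam_model_registry
from micro_sam.training.trainable_sam import TrainableSAM

# Build a SAM model; the model type and checkpoint path are placeholder assumptions.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
trainable_sam = TrainableSAM(sam)

# The wrapper is a regular nn.Module, so its parameters can be handed to an optimizer directly.
optimizer = torch.optim.AdamW(trainable_sam.parameters(), lr=1e-5)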
def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
Resize, normalize pixel values and pad to a square input.
Arguments:
- x: The input tensor.
Returns:
- The resized, normalized and padded tensor.
- The spatial shape (height, width) of the image after resizing, before padding.
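To make the resize-and-pad step concrete, here is a small sketch mirroring what preprocess does (the 512 x 768 input size is an arbitrary example, and the normalization with sam.pixel_mean / sam.pixel_std is omitted):

import torch
from torch.nn import functional as F
from segment_anything.utils.transforms import ResizeLongestSide

img_size = 1024                            # SAM's image encoder works on 1024 x 1024 inputs
transform = ResizeLongestSide(img_size)

x = torch.rand(1, 3, 512, 768)             # example batch of one image
x = transform.apply_image_torch(x)         # longest side scaled to 1024 -> (1, 3, 683, 1024)
input_size = x.shape[-2:]                  # (683, 1024), kept for postprocessing later
x = F.pad(x, (0, img_size - x.shape[-1], 0, img_size - x.shape[-2]))
print(x.shape, tuple(input_size))          # torch.Size([1, 3, 1024, 1024]) (683, 1024)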
def image_embeddings_oft(self, batched_inputs):

Computes the image embeddings once for the stacked input images and records the resized input size in each input record. Returns the image embeddings and the updated batched inputs.
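A hedged usage sketch, continuing from the construction example above (images in a batch are stacked before preprocessing, so they are assumed to share the same height and width):

# "image" is a 3 x H x W float tensor with pixel values in [0, 255], as expected by SAM's normalization.
batched_inputs = [
    {"image": 255 * torch.rand(3, 512, 768), "original_size": (512, 768)},
    {"image": 255 * torch.rand(3, 512, 768), "original_size": (512, 768)},
]

# Embeddings are computed once and can be reused across several prompt iterations.
image_embeddings, batched_inputs = trainable_sam.image_embeddings_oft(batched_inputs)
print(image_embeddings.shape)            # e.g. torch.Size([2, 256, 64, 64]) for the ViT-B encoder
print(batched_inputs[0]["input_size"])   # the resized size recorded by preprocess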
def forward(self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False) -> List[Dict[str, Any]]:
Forward pass.
Arguments:
- batched_inputs: The batched input images and prompts.
- image_embeddings: The precomputed image embeddings, e.g. obtained via image_embeddings_oft.
- multimask_output: Whether to predict multiple masks or just a single one.
Returns:
One dictionary per input with the predicted low-resolution masks, the masks resized to the original image size, and the predicted IoU values.
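Putting the pieces together, a hedged end-to-end sketch continuing from the examples above; the point prompt, its coordinates (given in the resized 1024-pixel input frame) and the ground-truth target are placeholders, and the loss is a toy choice rather than the training objective used by micro_sam:

# Attach one point prompt per image; shapes follow the SAM prompt encoder
# (num_prompt_sets x num_points x 2 for coordinates, num_prompt_sets x num_points for labels).
for record in batched_inputs:
    record["point_coords"] = torch.tensor([[[700.0, 400.0]]])  # placeholder foreground click
    record["point_labels"] = torch.tensor([[1]])

outputs = trainable_sam(batched_inputs, image_embeddings, multimask_output=False)

# Each output holds low-resolution logits, masks resized to the original image, and IoU predictions.
pred = outputs[0]
print(pred["low_res_masks"].shape)    # e.g. (1, 1, 256, 256)
print(pred["masks"].shape)            # e.g. (1, 1, 512, 768), matching original_size
print(pred["iou_predictions"].shape)  # e.g. (1, 1)

# Toy loss against a placeholder target, just to show that gradients flow through the wrapper.
target = torch.zeros_like(pred["masks"])
loss = torch.nn.functional.binary_cross_entropy_with_logits(pred["masks"], target)
loss.backward()
optimizer.step()  # optimizer from the construction sketch above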