micro_sam.models.peft_sam
Module imports:

```python
import math
from typing import List, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam
```
````python
class LoRASurgery(nn.Module):
    """Operates on the attention layers for performing low-rank adaptation.

    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)

    In SAM, it is implemented as:
    ```python
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    ```

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing lora.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features
        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
        self.rank = rank

        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)

        self.reset_parameters()

        block.attn.qkv = self

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.w_b_linear_q.weight)
        nn.init.zeros_(self.w_b_linear_v.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
        qkv[:, :, :, :self.dim] += new_q
        qkv[:, :, :, -self.dim:] += new_v
        return qkv
````
Operates on the attention layers for performing low-rank adaptation (LoRA).

(Inspired by: https://github.com/JamesQFreeman/Sam_LoRA/)

In SAM, the attention's qkv projection is implemented as:

```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```

`LoRASurgery` replaces `block.attn.qkv` in place and adds low-rank updates to the query and value projections; a minimal usage sketch follows the argument list.

Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention block for implementing LoRA.
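A minimal sketch of the surgery mechanics. The `SimpleNamespace` stand-in below is hypothetical and only exposes `attn.qkv`, the attribute `LoRASurgery` touches; in practice the block is a transformer block of SAM's image encoder (as passed in by `PEFT_Sam`):

```python
import torch
import torch.nn as nn
from types import SimpleNamespace
from micro_sam.models.peft_sam import LoRASurgery

dim = 768
# Hypothetical stand-in block exposing only `attn.qkv`.
block = SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3)))

lora = LoRASurgery(rank=4, block=block)
print(block.attn.qkv is lora)  # True: the qkv projection is replaced in place

x = torch.randn(1, 14, 14, dim)  # (B, H, W, C) token grid, as inside SAM's ViT blocks
out = block.attn.qkv(x)
print(out.shape)  # torch.Size([1, 14, 14, 2304]), i.e. 3 * dim, with LoRA updates added to q and v
```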
```python
class FacTSurgery(nn.Module):
    """Operates on the attention layers for performing factorized attention.

    (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing fact.
        dropout: The dropout rate for the factorized attention.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        dropout: Optional[float] = 0.1,
    ):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        self.q_FacTs = nn.Linear(rank, rank, bias=False)
        self.v_FacTs = nn.Linear(rank, rank, bias=False)

        self.dropout = dropout
        if self.dropout is not None:
            self.dp_q = nn.Dropout(self.dropout)
            self.dp_v = nn.Dropout(self.dropout)

        self.FacTu = nn.Linear(self.dim, rank, bias=False)
        self.FacTv = nn.Linear(rank, self.dim, bias=False)

        block.attn.qkv = self

    def forward(self, x):
        qkv = self.qkv_proj(x)

        new_q = self.q_FacTs(self.FacTu(x))
        new_v = self.v_FacTs(self.FacTu(x))

        if self.dropout is not None:
            new_q = self.dp_q(new_q)
            new_v = self.dp_v(new_v)

        new_q = self.FacTv(new_q)
        new_v = self.FacTv(new_v)

        # NOTE : Scaling Factor was set to 1 as it can be tuned via the learning rate
        qkv[:, :, :, : self.dim] += new_q
        qkv[:, :, :, -self.dim:] += new_v

        return qkv
```
Operates on the attention layers for performing factorized attention tuning (FacT).

(Inspired by: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

The projection factors `FacTu` and `FacTv` are shared between the query and value updates, so FacT adds fewer extra parameters than LoRA at the same rank; a short comparison sketch follows the argument list.

Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention block for implementing FacT.
- dropout: The dropout rate for the factorized attention.
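A small sketch comparing the number of parameters added by `LoRASurgery` and `FacTSurgery` on the same hypothetical stand-in block (numbers are for `dim=768`, the ViT-B embedding size, and `rank=4`):

```python
import torch.nn as nn
from types import SimpleNamespace
from micro_sam.models.peft_sam import LoRASurgery, FacTSurgery

dim, rank = 768, 4

def make_block():
    # Hypothetical stand-in exposing only `attn.qkv`, the attribute both surgeries touch.
    return SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3)))

def added_parameters(surgery):
    # Count only the parameters the surgery introduces, i.e. everything except the wrapped qkv projection.
    return sum(p.numel() for name, p in surgery.named_parameters() if not name.startswith("qkv_proj"))

lora = LoRASurgery(rank=rank, block=make_block())
fact = FacTSurgery(rank=rank, block=make_block())

print(added_parameters(lora))  # 4 * rank * dim = 12288
print(added_parameters(fact))  # 2 * rank * dim + 2 * rank**2 = 6176
```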
```python
class ScaleShiftLayer(nn.Module):
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer
        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
        layer = self

    def forward(self, x):
        x = self.layer(x)
        assert self.scale.shape == self.shift.shape
        if x.shape[-1] == self.scale.shape[0]:
            return x * self.scale + self.shift
        elif x.shape[1] == self.scale.shape[0]:
            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
        else:
            raise ValueError('Input tensors do not match the shape of the scale factors.')
```
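`ScaleShiftLayer` carries no docstring of its own: it wraps an existing layer and applies a learnable elementwise scale and shift to that layer's output (broadcast over the channel dimension for convolutional outputs). A minimal sketch, wrapping a plain `nn.Linear`:

```python
import torch
import torch.nn as nn
from micro_sam.models.peft_sam import ScaleShiftLayer

lin = nn.Linear(768, 3072)
ssf_lin = ScaleShiftLayer(lin, dim=3072)  # output = lin(x) * scale + shift, with learnable scale/shift

x = torch.randn(2, 196, 768)
y = ssf_lin(x)
print(y.shape)  # torch.Size([2, 196, 3072])
```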
```python
class SSFSurgery(nn.Module):
    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.

    Args:
        rank: This parameter is not used in `SSFSurgery`. This is kept here for consistency.
        block: The chosen attention blocks for implementing ssf.
        dim: The input dimensions determining the shape of scale and shift parameters.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.block = block

        # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])

        # If we get the embedding block, add one ScaleShiftLayer
        elif hasattr(block, "patch_embed"):
            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)

    def forward(self, x):
        return x
```
Operates on all layers in the transformer block for adding learnable scale and shift parameters (SSF).

The surgery wraps a block's qkv and output projections, both MLP linears, and both layer norms with `ScaleShiftLayer`; a minimal sketch on a stand-in block follows the argument list.

Arguments:
- rank: This parameter is not used in `SSFSurgery`. It is kept here for a consistent interface.
- block: The chosen attention block for implementing SSF.
- dim: The input dimensions determining the shape of the scale and shift parameters.
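A minimal sketch. The `SimpleNamespace` stand-in below is hypothetical and only provides the sub-layers `SSFSurgery` looks for; in practice `PEFT_Sam` passes the real SAM image encoder blocks:

```python
import torch
import torch.nn as nn
from types import SimpleNamespace
from micro_sam.models.peft_sam import SSFSurgery, ScaleShiftLayer

dim = 768
# Hypothetical stand-in with the sub-layers SSFSurgery wraps in a SAM transformer block.
block = SimpleNamespace(
    attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3), proj=nn.Linear(dim, dim)),
    mlp=SimpleNamespace(lin1=nn.Linear(dim, 4 * dim), lin2=nn.Linear(4 * dim, dim)),
    norm1=nn.LayerNorm(dim),
    norm2=nn.LayerNorm(dim),
)

ssf = SSFSurgery(rank=1, block=block)  # rank is ignored by SSFSurgery

print(isinstance(block.attn.qkv, ScaleShiftLayer))  # True: the qkv projection is now wrapped
x = torch.randn(1, 196, dim)
print(block.norm1(x).shape)  # torch.Size([1, 196, 768]); the norm output is scaled and shifted
```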
```python
class SelectiveSurgery(nn.Module):
    """Base class for selectively allowing gradient updates for certain parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def allow_gradient_update_for_parameters(
        self,
        prefix: Optional[List[str]] = None,
        suffix: Optional[List[str]] = None,
        infix: Optional[List[str]] = None,
    ):
        """This function decides the parameter attributes to match for allowing gradient updates.

        Args:
            prefix: Matches the part of parameter name in front.
            suffix: Matches the part of parameter name at the end.
            infix: Matches parts of parameter name occuring in between.
        """
        for k, v in self.block.named_parameters():
            if prefix is not None and k.startswith(tuple(prefix)):
                v.requires_grad = True

            if suffix is not None and k.endswith(tuple(suffix)):
                v.requires_grad = True

            if infix is not None:
                for per_infix in infix:
                    if k.find(per_infix) != -1:
                        v.requires_grad = True

    def forward(self, x):
        return x
```
Base class for selectively allowing gradient updates for certain parameters.
Decides which parameter names to match when re-enabling gradient updates; a minimal sketch follows the argument list.

Arguments:
- prefix: Matches the beginning of a parameter name.
- suffix: Matches the end of a parameter name.
- infix: Matches substrings occurring anywhere within a parameter name.
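A small sketch using a toy `nn.ModuleDict` as a hypothetical stand-in block. Freezing everything and then re-enabling only the bias parameters mirrors what `BiasSurgery` does via `suffix=["bias"]`:

```python
import torch.nn as nn
from micro_sam.models.peft_sam import SelectiveSurgery

# Hypothetical toy block with a norm and a linear layer, frozen up front.
block = nn.ModuleDict({"norm1": nn.LayerNorm(8), "lin": nn.Linear(8, 8)})
for p in block.parameters():
    p.requires_grad = False

surgery = SelectiveSurgery(block=block)
surgery.allow_gradient_update_for_parameters(suffix=["bias"])

print([name for name, p in block.named_parameters() if p.requires_grad])
# ['norm1.bias', 'lin.bias']
```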
```python
class AdaptFormer(nn.Module):
    """Adds AdaptFormer Module in place of the MLP Layers

    Args:
        rank: The rank is not used in this class but kept here for consistency.
        block: The chosen encoder block for implementing AdaptFormer.
        alpha: A parameters that scales the Adapter path. Can be either learnable or some fixed value.
        dropout: The dropout rate for the dropout layer between down and up projection layer.
        projection_size: The size of the projection layer.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
        dropout: Optional[float] = None,  # Does not have an obvious advantage.
        projection_size: int = 64,  # Stable choice from our preliminary exp.
    ):
        super().__init__()

        self.mlp_proj = block.mlp
        self.n_embd = block.mlp.lin1.in_features

        if alpha == 'learnable_scalar':
            self.alpha = nn.Parameter(torch.ones(1))
        else:
            self.alpha = alpha

        self.projection_size = projection_size
        self.dropout = dropout

        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.projection_size, self.n_embd)

        block.mlp = self

        if self.dropout is not None:
            self.dropout_layer = nn.Dropout(self.dropout)

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        nn.init.zeros_(self.down_proj.bias)
        nn.init.zeros_(self.up_proj.bias)

    def forward(self, x):
        residual = x
        mlp_output = self.mlp_proj(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)

        if self.dropout is not None:
            down = self.dropout_layer(down)

        up = self.up_proj(down)
        up = up * self.alpha
        output = up + residual + mlp_output

        return output
```
Adds an AdaptFormer module in place of the MLP layers; a minimal sketch follows the argument list.

Arguments:
- rank: The rank is not used in this class but kept here for consistency.
- block: The chosen encoder block for implementing AdaptFormer.
- alpha: A parameter that scales the adapter path. Can be either learnable or a fixed value.
- dropout: The dropout rate for the dropout layer between the down and up projection layers.
- projection_size: The size of the projection layer.
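A minimal sketch. `DummyMLP` is a hypothetical stand-in mirroring the `lin1`/`lin2` attributes the surgery expects on `block.mlp`; in practice the block is a SAM image encoder block:

```python
import torch
import torch.nn as nn
from types import SimpleNamespace
from micro_sam.models.peft_sam import AdaptFormer

class DummyMLP(nn.Module):
    # Hypothetical stand-in exposing `lin1`/`lin2`, as AdaptFormer expects.
    def __init__(self, dim, hidden):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden)
        self.lin2 = nn.Linear(hidden, dim)
        self.act = nn.GELU()

    def forward(self, x):
        return self.lin2(self.act(self.lin1(x)))

dim = 768
block = SimpleNamespace(mlp=DummyMLP(dim, 4 * dim))
adapter = AdaptFormer(rank=4, block=block, projection_size=64)  # rank is unused by AdaptFormer

x = torch.randn(1, 14, 14, dim)
y = block.mlp(x)  # mlp(x) + x + alpha * up_proj(ReLU(down_proj(x)))
# up_proj is zero-initialized, so right after construction the adapter path contributes nothing.
print(torch.allclose(y, adapter.mlp_proj(x) + x))  # True
print(y.shape)  # torch.Size([1, 14, 14, 768])
```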
```python
class AttentionSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for parameters in attention layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the attention layers in the image encoder.
        self.allow_gradient_update_for_parameters(prefix=["attn"])
```
Child class for allowing gradient updates for parameters in attention layers.
```python
class BiasSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for bias parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the bias parameters in the image encoder.
        self.allow_gradient_update_for_parameters(suffix=["bias"])
```
Child class for allowing gradient updates for bias parameters.
```python
class LayerNormSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates in normalization layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the LayerNorm parameters in the image encoder.
        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])
```
Child class for allowing gradient updates in normalization layers.
```python
class PEFT_Sam(nn.Module):
    """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.

    Args:
        model: The Segment Anything model.
        rank: The rank for low-rank adaptation.
        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
        attention_layers_to_update: Which specific layers we apply PEFT methods to.
    """

    def __init__(
        self,
        model: Sam,
        rank: int,
        peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Union[List[int]] = None,
        **module_kwargs
    ):
        super().__init__()

        assert rank > 0

        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
            "Invalid PEFT module"
        )

        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
        else:  # Applies PEFT to the image encoder by default
            self.peft_layers = list(range(len(model.image_encoder.blocks)))

        self.peft_module = peft_module
        self.peft_blocks = []

        # let's freeze all the pretrained image encoder layers first
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        # Add scale and shift parameters to the patch embedding layers.
        if issubclass(self.peft_module, SSFSurgery):
            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
            # If we only want specific layers with PEFT instead of all
            if t_layer_i not in self.peft_layers:
                continue

            if issubclass(self.peft_module, SelectiveSurgery):
                self.peft_blocks.append(self.peft_module(block=blk))
            else:
                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

        self.peft_blocks = nn.ModuleList(self.peft_blocks)

        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)
```
Wraps the Segment Anything model's image encoder with different parameter-efficient finetuning (PEFT) methods; a minimal usage sketch follows the argument list.

Arguments:
- model: The Segment Anything model.
- rank: The rank for low-rank adaptation.
- peft_module: The wrapper to operate on the image encoder blocks for the PEFT method.
- attention_layers_to_update: Which specific layers to apply PEFT methods to. If not given, all image encoder blocks are updated.
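A minimal usage sketch, assuming `segment_anything` is installed. The checkpoint is omitted here, so the SAM weights are randomly initialized; the sketch only illustrates the wrapping and the resulting trainable-parameter split:

```python
from segment_anything import sam_model_registry
from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

sam = sam_model_registry["vit_b"](checkpoint=None)  # randomly initialized SAM, for illustration only
peft_sam = PEFT_Sam(model=sam, rank=4, peft_module=LoRASurgery)

# The image encoder is frozen except for the injected LoRA weights;
# the prompt encoder and mask decoder stay trainable.
trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
total = sum(p.numel() for p in peft_sam.parameters())
print(f"trainable parameters: {trainable} / {total}")
```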