micro_sam.models.peft_sam
Full module source:

````python
import math
from typing import List, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam

try:
    import bitsandbytes as bnb
    _have_bnb = True
except ImportError:
    _have_bnb = False


class LoRASurgery(nn.Module):
    """Operates on the linear layers (attention and/or other feed forward) for performing low-rank adaptation.

    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)

    In SAM, it is implemented as:
    ```python
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    ```

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing LoRA.
        update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v", "mlp".
    """
    def __init__(self, rank: int, block: nn.Module, update_matrices: List[str] = ["q", "v"]):
        super().__init__()
        # Check whether all values for "update_matrices" are as expected.
        if set(update_matrices) - set(["q", "k", "v", "mlp"]):
            raise ValueError(f"Some of the expected keys for updating matrics in '{update_matrices}' are not expected.")

        self.block = block
        block.attn.qkv = AttentionLoRA(rank=rank, block=block.attn.qkv, update_matrices=update_matrices)

        if "mlp" in update_matrices:
            block.mlp = MLPLoRA(rank=rank, mlp_layer=block.mlp)

    def forward(self, x):
        return x


class AttentionLoRA(nn.Module):
    """Operates on the attention layers only for performing low-rank adaptation.

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing LoRA.
        update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v".
    """

    def __init__(self, rank: int, block: nn.Module, update_matrices: List[str] = ["q", "v"]):
        super().__init__()
        self.qkv_proj = block
        self.dim = self.qkv_proj.in_features
        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
        self.rank = rank

        # By default, we follow LoRA's recommended setup, i.e. update the "q" and "v" matrices.
        if "q" in update_matrices:
            self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
            self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)

        if "v" in update_matrices:
            self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
            self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)

        if "k" in update_matrices:
            self.w_a_linear_k = nn.Linear(self.dim, self.rank, bias=False)
            self.w_b_linear_k = nn.Linear(self.rank, self.dim, bias=False)

        self.reset_parameters()

        block = self

    def reset_parameters(self):
        if hasattr(self, "w_a_linear_q"):
            nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
            nn.init.zeros_(self.w_b_linear_q.weight)

        if hasattr(self, "w_a_linear_v"):
            nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
            nn.init.zeros_(self.w_b_linear_v.weight)

        if hasattr(self, "w_a_linear_k"):
            nn.init.kaiming_uniform_(self.w_a_linear_k.weight, a=math.sqrt(5))
            nn.init.zeros_(self.w_b_linear_k.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C

        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x)) if hasattr(self, "w_a_linear_q") else 0
        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x)) if hasattr(self, "w_a_linear_v") else 0
        new_k = self.alpha * self.w_b_linear_k(self.w_a_linear_k(x)) if hasattr(self, "w_a_linear_k") else 0
        qkv = torch.cat(
            [
                qkv[:, :, :, :self.dim] + new_q,  # replacing new q values.
                qkv[:, :, :, self.dim:-self.dim] + new_k,  # replacing new k values.
                qkv[:, :, :, -self.dim:] + new_v  # replacing new v values.
            ], dim=-1
        )

        return qkv


class MLPLoRA(nn.Module):
    """Operates on the feed forward layers for performing low-rank adaptation.

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        mlp_layer: The chosen MLP layer for implementing LoRA.
    """

    def __init__(self, rank: int, mlp_layer: nn.Module):
        super().__init__()

        self.mlp_layer = mlp_layer
        self.rank = rank
        self.w_a_linear_1 = nn.Linear(mlp_layer.lin1.in_features, rank, bias=False)
        self.w_b_linear_1 = nn.Linear(rank, mlp_layer.lin1.out_features, bias=False)
        self.w_a_linear_2 = nn.Linear(mlp_layer.lin2.in_features, rank, bias=False)
        self.w_b_linear_2 = nn.Linear(rank, mlp_layer.lin2.out_features, bias=False)
        self.activation = mlp_layer.act

        self.reset_parameters()

        mlp_layer = self

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.w_a_linear_1.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w_a_linear_2.weight, a=math.sqrt(5))
        nn.init.zeros_(self.w_b_linear_1.weight)
        nn.init.zeros_(self.w_b_linear_2.weight)

    def forward(self, x):
        x = self.mlp_layer.lin1(x) + self.w_b_linear_1(self.w_a_linear_1(x))
        x = self.activation(x)
        x = self.mlp_layer.lin2(x) + self.w_b_linear_2(self.w_a_linear_2(x))
        return x


class FacTSurgery(nn.Module):
    """Operates on the attention layers for performing factorized attention.

    (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing fact.
        dropout: The dropout rate for the factorized attention.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        dropout: Optional[float] = 0.1,
    ):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        self.q_FacTs = nn.Linear(rank, rank, bias=False)
        self.v_FacTs = nn.Linear(rank, rank, bias=False)

        self.dropout = dropout
        if self.dropout is not None:
            self.dp_q = nn.Dropout(self.dropout)
            self.dp_v = nn.Dropout(self.dropout)

        self.FacTu = nn.Linear(self.dim, rank, bias=False)
        self.FacTv = nn.Linear(rank, self.dim, bias=False)

        block.attn.qkv = self

    def forward(self, x):
        qkv = self.qkv_proj(x)

        new_q = self.q_FacTs(self.FacTu(x))
        new_v = self.v_FacTs(self.FacTu(x))

        if self.dropout is not None:
            new_q = self.dp_q(new_q)
            new_v = self.dp_v(new_v)

        new_q = self.FacTv(new_q)
        new_v = self.FacTv(new_v)

        # NOTE : Scaling Factor is set to 1 as it can be tuned via the learning rate.
        qkv = torch.cat(
            [
                qkv[:, :, :, :self.dim] + new_q,  # replacing new q values
                qkv[:, :, :, self.dim:-self.dim],  # leaving the middle part as identical
                qkv[:, :, :, -self.dim:] + new_v  # replacing new v values
            ], dim=-1
        )

        return qkv


class ScaleShiftLayer(nn.Module):
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer
        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
        layer = self

    def forward(self, x):
        x = self.layer(x)
        assert self.scale.shape == self.shift.shape
        if x.shape[-1] == self.scale.shape[0]:
            return x * self.scale + self.shift
        elif x.shape[1] == self.scale.shape[0]:
            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
        else:
            raise ValueError('Input tensors do not match the shape of the scale factors.')


class SSFSurgery(nn.Module):
    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.

    Args:
        rank: This parameter is not used in `SSFSurgery`. This is kept here for consistency.
        block: The chosen attention blocks for implementing ssf.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.block = block

        # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])

        # If we get the embedding block, add one ScaleShiftLayer
        elif hasattr(block, "patch_embed"):
            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)

    def forward(self, x):
        return x


class SelectiveSurgery(nn.Module):
    """Base class for selectively allowing gradient updates for certain parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def allow_gradient_update_for_parameters(
        self,
        prefix: Optional[List[str]] = None,
        suffix: Optional[List[str]] = None,
        infix: Optional[List[str]] = None,
    ):
        """This function decides the parameter attributes to match for allowing gradient updates.

        Args:
            prefix: Matches the part of parameter name in front.
            suffix: Matches the part of parameter name at the end.
            infix: Matches parts of parameter name occuring in between.
        """
        for k, v in self.block.named_parameters():
            if prefix is not None and k.startswith(tuple(prefix)):
                v.requires_grad = True

            if suffix is not None and k.endswith(tuple(suffix)):
                v.requires_grad = True

            if infix is not None:
                for per_infix in infix:
                    if k.find(per_infix) != -1:
                        v.requires_grad = True

    def forward(self, x):
        return x


class AdaptFormer(nn.Module):
    """Adds AdaptFormer Module in place of the MLP Layers

    Args:
        rank: The rank is not used in this class but kept here for consistency.
        block: The chosen encoder block for implementing AdaptFormer.
        alpha: A parameters that scales the Adapter path. Can be either learnable or some fixed value.
        dropout: The dropout rate for the dropout layer between down and up projection layer.
        projection_size: The size of the projection layer.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
        dropout: Optional[float] = None,  # Does not have an obvious advantage.
        projection_size: int = 64,  # Stable choice from our preliminary exp.
    ):
        super().__init__()

        self.mlp_proj = block.mlp
        self.n_embd = block.mlp.lin1.in_features

        if alpha == 'learnable_scalar':
            self.alpha = nn.Parameter(torch.ones(1))
        else:
            self.alpha = alpha

        self.projection_size = projection_size
        self.dropout = dropout

        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.projection_size, self.n_embd)

        block.mlp = self

        if self.dropout is not None:
            self.dropout_layer = nn.Dropout(self.dropout)

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        nn.init.zeros_(self.down_proj.bias)
        nn.init.zeros_(self.up_proj.bias)

    def forward(self, x):
        residual = x
        mlp_output = self.mlp_proj(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)

        if self.dropout is not None:
            down = self.dropout_layer(down)

        up = self.up_proj(down)
        up = up * self.alpha
        output = up + residual + mlp_output

        return output


class AttentionSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for parameters in attention layers."""

    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the attention layers in the image encoder.
        self.allow_gradient_update_for_parameters(prefix=["attn"])


class BiasSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for bias parameters."""

    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the bias parameters in the image encoder.
        self.allow_gradient_update_for_parameters(suffix=["bias"])


class LayerNormSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates in normalization layers."""

    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the LayerNorm parameters in the image encoder.
        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])


class ClassicalSurgery(SelectiveSurgery):
    """Child class for freezing specific blocks."""

    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        self.block = block

        for k, v in self.block.named_parameters():
            v.requires_grad = True

    def forward(self, x):
        return x


class PEFT_Sam(nn.Module):
    """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.

    Args:
        model: The Segment Anything model.
        rank: The rank for low-rank adaptation.
        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
        attention_layers_to_update: Which specific layers we apply PEFT methods to.
            For reference, the total number of blocks for 'vit_b' is 12, for 'vit_l' is 24 and for 'vit_h' is 32.
        quantize: Whether to quantize the model for lower precision training.
        module_kwargs: The additional arguments for the respective PEFT modules.
    """

    def __init__(
        self,
        model: Sam,
        rank: Optional[int] = None,
        peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Optional[List[int]] = None,
        quantize: bool = False,
        **module_kwargs
    ):
        super().__init__()

        if issubclass(peft_module, Union[LoRASurgery, FacTSurgery]) and (not rank or rank <= 0):
            raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")

        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
            "Invalid PEFT module"
        )
        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
        else:  # Applies PEFT to the image encoder by default
            self.peft_layers = list(range(len(model.image_encoder.blocks)))

        self.peft_module = peft_module
        self.peft_blocks = []

        # Whether to quantize the linear layers to 4 bit precision.
        # NOTE: This is currently supported for CUDA-supported devices only.
        if quantize:
            if not _have_bnb:
                raise ModuleNotFoundError("Please install 'bitsandbytes'.")

            for name, module in model.image_encoder.named_modules():
                if isinstance(module, torch.nn.Linear):
                    *parent_path, layer_name = name.split(".")
                    parent_module = model.image_encoder

                    for sub_module in parent_path:
                        parent_module = getattr(parent_module, sub_module)

                    # Create the new Linear4bit layer
                    linear_q = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        bias=False if module.bias is None else True,
                    )
                    # Assign weights and bias to the new layer
                    new_weight = bnb.nn.Params4bit(
                        data=module.weight,
                        requires_grad=False,
                    )
                    linear_q.weight = new_weight
                    if module.bias is not None:
                        linear_q.bias = torch.nn.Parameter(module.bias)

                    # Replace the original linear layer with the quantized one
                    setattr(parent_module, layer_name, linear_q)

        # Let's freeze all the pretrained image encoder layers first
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        # Add scale and shift parameters to the patch embedding layers.
        if issubclass(self.peft_module, SSFSurgery):
            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))

        # If specified, the attention layers to update should match the available blocks.
        if attention_layers_to_update and (
            set(attention_layers_to_update) - set(list(range(len(model.image_encoder.blocks))))
        ):
            raise ValueError("The chosen layer(s) to apply PEFT method is not a valid transformer block id.")

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):

            # If we only want specific layers with PEFT instead of all
            if t_layer_i not in self.peft_layers:
                continue

            if issubclass(self.peft_module, SelectiveSurgery):
                self.peft_blocks.append(self.peft_module(block=blk))
            else:
                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

        self.peft_blocks = nn.ModuleList(self.peft_blocks)
        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)
````
class LoRASurgery(nn.Module):
Operates on the linear layers (attention and/or other feed forward) for performing low-rank adaptation.
(Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)
In SAM, it is implemented as:
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```
Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention blocks for implementing LoRA.
- update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v", "mlp".
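A minimal usage sketch, not taken from the library's documentation: it applies `LoRASurgery` to a stand-alone ViT block built with `segment_anything`'s `Block` class; the dimensions are illustrative ViT-B values.

```python
import torch
from segment_anything.modeling.image_encoder import Block
from micro_sam.models.peft_sam import LoRASurgery

# A stand-alone ViT-B style encoder block (dim/num_heads are illustrative).
blk = Block(dim=768, num_heads=12)

# Replaces blk.attn.qkv (and blk.mlp) with their LoRA-wrapped versions in place.
LoRASurgery(rank=4, block=blk, update_matrices=["q", "v", "mlp"])

x = torch.randn(1, 14, 14, 768)   # (B, H, W, C) token grid, as in SAM's image encoder
print(blk(x).shape)               # torch.Size([1, 14, 14, 768])
```

Because the `w_b_*` matrices are zero-initialized, the wrapped block initially computes the same function as the original one; training only updates the low-rank matrices once the pretrained weights are frozen.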
class AttentionLoRA(nn.Module):
Operates on the attention layers only for performing low-rank adaptation.
Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention blocks for implementing LoRA.
- update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v".
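A self-contained sketch (illustrative sizes, not from the library's documentation) that wraps a bare `nn.Linear` standing in for `block.attn.qkv`:

```python
import torch
import torch.nn as nn
from micro_sam.models.peft_sam import AttentionLoRA

dim, rank = 64, 4
qkv = nn.Linear(dim, dim * 3)                # stand-in for block.attn.qkv
lora_qkv = AttentionLoRA(rank=rank, block=qkv, update_matrices=["q", "v"])

x = torch.randn(1, 8, 8, dim)                # (B, H, W, C), as in SAM's ViT blocks
out = lora_qkv(x)
print(out.shape)                             # torch.Size([1, 8, 8, 192])

# The w_b_* matrices are zero-initialized, so the wrapped projection starts out
# identical to the original qkv projection.
print(torch.allclose(out, qkv(x)))           # True
```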
class MLPLoRA(nn.Module):
Operates on the feed forward layers for performing low-rank adaptation.
Arguments:
- rank: The rank of the decomposition matrices for updating weights in the wrapped MLP layer.
- mlp_layer: The chosen MLP layer for implementing LoRA.
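A minimal sketch (illustrative sizes): `MLPBlock` from `segment_anything.modeling.common` is the feed-forward module used inside SAM's encoder blocks and exposes the `lin1`/`lin2`/`act` attributes that `MLPLoRA` expects.

```python
import torch
from segment_anything.modeling.common import MLPBlock
from micro_sam.models.peft_sam import MLPLoRA

mlp = MLPBlock(embedding_dim=768, mlp_dim=3072)   # SAM-style feed-forward block
mlp_lora = MLPLoRA(rank=4, mlp_layer=mlp)

x = torch.randn(2, 196, 768)
print(mlp_lora(x).shape)                          # torch.Size([2, 196, 768])
# The w_b_* matrices are zero-initialized, so mlp_lora(x) == mlp(x) at this point.
```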
class FacTSurgery(nn.Module):
Operates on the attention layers for performing factorized attention.
(Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)
Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention blocks for implementing fact.
- dropout: The dropout rate for the factorized attention.
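A minimal sketch (illustrative ViT-B dimensions, not from the library's documentation). Within a block, the `FacTu`/`FacTv` projections are shared between the q and v updates, which keeps the number of added parameters small.

```python
import torch
from segment_anything.modeling.image_encoder import Block
from micro_sam.models.peft_sam import FacTSurgery

blk = Block(dim=768, num_heads=12)           # illustrative ViT-B style block
FacTSurgery(rank=4, block=blk, dropout=0.1)  # swaps blk.attn.qkv for the FacT wrapper

x = torch.randn(1, 14, 14, 768)
print(blk(x).shape)                          # torch.Size([1, 14, 14, 768])
```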
class ScaleShiftLayer(nn.Module):
Wraps a given layer and applies learnable per-feature scale and shift parameters to its output. Used internally by `SSFSurgery`.
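A small self-contained sketch (illustrative sizes) wrapping a plain linear layer:

```python
import torch
import torch.nn as nn
from micro_sam.models.peft_sam import ScaleShiftLayer

lin = nn.Linear(768, 768)
ssf_lin = ScaleShiftLayer(layer=lin, dim=768)   # adds per-feature scale and shift

x = torch.randn(4, 196, 768)
print(ssf_lin(x).shape)                         # torch.Size([4, 196, 768])
```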
class SSFSurgery(nn.Module):
Operates on all layers in the transformer block for adding learnable scale and shift parameters.
Arguments:
- rank: This parameter is not used in SSFSurgery. This is kept here for consistency.
- block: The chosen attention blocks for implementing ssf.
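A minimal sketch (illustrative ViT-B dimensions): the surgery wraps the block's qkv and output projections, both MLP linears, and the two layer norms with `ScaleShiftLayer` in place.

```python
import torch
from segment_anything.modeling.image_encoder import Block
from micro_sam.models.peft_sam import SSFSurgery

blk = Block(dim=768, num_heads=12)   # illustrative ViT-B style block
SSFSurgery(rank=None, block=blk)     # rank is ignored; sub-layers get wrapped in place

x = torch.randn(1, 14, 14, 768)
print(blk(x).shape)                  # torch.Size([1, 14, 14, 768])
```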
class SelectiveSurgery(nn.Module):
Base class for selectively allowing gradient updates for certain parameters.
def allow_gradient_update_for_parameters(prefix: Optional[List[str]] = None, suffix: Optional[List[str]] = None, infix: Optional[List[str]] = None):
This function decides the parameter attributes to match for allowing gradient updates.
Arguments:
- prefix: Matches the part of parameter name in front.
- suffix: Matches the part of parameter name at the end.
- infix: Matches parts of parameter name occurring in between.
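A minimal sketch (illustrative; the chosen prefixes/suffixes/infixes are examples, not library defaults) showing how the matching works on a frozen block:

```python
from segment_anything.modeling.image_encoder import Block
from micro_sam.models.peft_sam import SelectiveSurgery

blk = Block(dim=768, num_heads=12)
for p in blk.parameters():
    p.requires_grad = False          # start from a fully frozen block

surgery = SelectiveSurgery(block=blk)
# Unfreeze every bias term plus anything inside the attention output projection.
surgery.allow_gradient_update_for_parameters(suffix=["bias"], infix=["attn.proj"])

trainable = [k for k, v in blk.named_parameters() if v.requires_grad]
print(trainable[:5])
```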
class AdaptFormer(nn.Module):
Adds the AdaptFormer module in place of the MLP layers.
Arguments:
- rank: The rank is not used in this class but kept here for consistency.
- block: The chosen encoder block for implementing AdaptFormer.
- alpha: A parameter that scales the adapter path. Can be either learnable or a fixed value.
- dropout: The dropout rate for the dropout layer between down and up projection layer.
- projection_size: The size of the projection layer.
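A minimal sketch (illustrative ViT-B dimensions). Since the up projection is zero-initialized, the adapter path contributes nothing at initialization, and `rank` is unused here.

```python
import torch
from segment_anything.modeling.image_encoder import Block
from micro_sam.models.peft_sam import AdaptFormer

blk = Block(dim=768, num_heads=12)   # illustrative ViT-B style block
AdaptFormer(rank=None, block=blk, alpha="learnable_scalar", projection_size=64)

x = torch.randn(1, 14, 14, 768)
print(blk(x).shape)                  # torch.Size([1, 14, 14, 768])
```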
class AttentionSurgery(SelectiveSurgery):
Child class for allowing gradient updates for parameters in attention layers.
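A minimal sketch (illustrative block, not from the library's documentation) that also covers the related `BiasSurgery` and `LayerNormSurgery` below; all three expect the block to have been frozen first, as `PEFT_Sam` does:

```python
from segment_anything.modeling.image_encoder import Block
from micro_sam.models.peft_sam import AttentionSurgery, BiasSurgery, LayerNormSurgery

blk = Block(dim=768, num_heads=12)
for p in blk.parameters():
    p.requires_grad = False          # PEFT_Sam freezes the image encoder before surgery

AttentionSurgery(block=blk)          # re-enables gradients for all "attn.*" parameters
# BiasSurgery(block=blk)             # would re-enable every "*.bias" parameter (BitFit-style)
# LayerNormSurgery(block=blk)        # would re-enable the "norm1"/"norm2" parameters

print(sorted(k for k, v in blk.named_parameters() if v.requires_grad))
```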
class BiasSurgery(SelectiveSurgery):
Child class for allowing gradient updates for bias parameters.
class LayerNormSurgery(SelectiveSurgery):
Child class for allowing gradient updates in normalization layers.
class ClassicalSurgery(SelectiveSurgery):
Child class for freezing specific blocks.
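As the source above shows, the surgery switches every parameter of the given block back to trainable; a minimal sketch (illustrative block):

```python
from segment_anything.modeling.image_encoder import Block
from micro_sam.models.peft_sam import ClassicalSurgery

blk = Block(dim=768, num_heads=12)
for p in blk.parameters():
    p.requires_grad = False

ClassicalSurgery(block=blk)   # marks every parameter of this block as trainable again
print(all(p.requires_grad for p in blk.parameters()))   # True
```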
class PEFT_Sam(nn.Module):
Wraps the Segment Anything model's image encoder with different parameter-efficient finetuning (PEFT) methods (a usage sketch follows the argument list below).
Arguments:
- model: The Segment Anything model.
- rank: The rank for low-rank adaptation.
- peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
- attention_layers_to_update: The specific transformer blocks to which the PEFT method is applied. For reference, the total number of blocks is 12 for 'vit_b', 24 for 'vit_l' and 32 for 'vit_h'.
- quantize: Whether to quantize the model for lower precision training.
- module_kwargs: The additional arguments for the respective PEFT modules.
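A minimal usage sketch (not part of the module source): it builds a SAM backbone with segment_anything's sam_model_registry and wraps it with LoRA on the last three 'vit_b' blocks. The checkpoint path, rank and layer indices are illustrative assumptions.

```python
from segment_anything import sam_model_registry

from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

# Build the SAM backbone (placeholder checkpoint path).
sam = sam_model_registry["vit_b"](checkpoint="/path/to/sam_vit_b.pth")

# Wrap it: LoRA with rank 4 on the last three of the 12 'vit_b' blocks,
# updating only the "q" and "v" projections (forwarded via **module_kwargs).
peft_sam = PEFT_Sam(
    model=sam,
    rank=4,
    peft_module=LoRASurgery,
    attention_layers_to_update=[9, 10, 11],
    update_matrices=["q", "v"],
)

# The image encoder is frozen; only the LoRA matrices (and the parts of SAM
# outside the image encoder) remain trainable.
n_trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_trainable}")
```

For SelectiveSurgery subclasses no rank is required, and passing quantize=True additionally swaps the encoder's nn.Linear layers for 4-bit bitsandbytes layers (CUDA devices only).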
```python
def forward(self, batched_input, multimask_output):
    return self.sam(batched_input, multimask_output)
```
Define the computation performed at every call.

Should be overridden by all subclasses. Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
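Continuing the sketch above: call the wrapper instance rather than forward() directly, so the registered hooks run. The batched input below follows segment_anything's Sam.forward convention; the zero image and the single dummy point prompt are illustrative assumptions.

```python
import torch

# One image record in segment_anything's batched-input format; prompts are optional.
batched_input = [{
    "image": torch.zeros(3, 1024, 1024),   # image tensor in model input space (illustrative)
    "original_size": (1024, 1024),         # size of the image before resizing
    "point_coords": torch.zeros(1, 1, 2),  # one dummy point prompt
    "point_labels": torch.ones(1, 1),      # label 1 = foreground point
}]

# Preferred: call the module instance (runs hooks), not peft_sam.forward(...).
outputs = peft_sam(batched_input, multimask_output=False)
```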
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
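Because PEFT_Sam is a regular nn.Module, the inherited members listed above (named_parameters, state_dict, load_state_dict) are enough to checkpoint only the trainable weights. The snippet below is a hedged sketch continuing the earlier example, not part of micro_sam's API; the file name is a placeholder.

```python
import torch

# Names of the parameters left trainable by PEFT_Sam (the low-rank matrices
# plus SAM's non-frozen prompt encoder and mask decoder).
trainable_names = {name for name, param in peft_sam.named_parameters() if param.requires_grad}

# Save a lightweight checkpoint containing only those tensors.
peft_state = {name: tensor for name, tensor in peft_sam.state_dict().items() if name in trainable_names}
torch.save(peft_state, "peft_sam_trainable.pt")  # placeholder file name

# Restore onto a freshly built, identically configured wrapper; strict=False
# leaves the frozen backbone weights untouched.
missing, unexpected = peft_sam.load_state_dict(torch.load("peft_sam_trainable.pt"), strict=False)
```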