micro_sam.models.peft_sam
````python
import math
from typing import List, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam

try:
    import bitsandbytes as bnb
    _have_bnb = True
except ImportError:
    _have_bnb = False


class LoRASurgery(nn.Module):
    """Operates on the attention layers for performing low-rank adaptation.

    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)

    In SAM, it is implemented as:
    ```python
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    ```

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing LoRA.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features
        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
        self.rank = rank

        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)

        self.reset_parameters()

        block.attn.qkv = self

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.w_b_linear_q.weight)
        nn.init.zeros_(self.w_b_linear_v.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
        qkv = torch.cat(
            [
                qkv[:, :, :, :self.dim] + new_q,  # replacing new q values
                qkv[:, :, :, self.dim:-self.dim],  # leaving the middle part as identical
                qkv[:, :, :, -self.dim:] + new_v  # replacing new v values
            ], dim=-1
        )

        return qkv


class FacTSurgery(nn.Module):
    """Operates on the attention layers for performing factorized attention.

    (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing fact.
        dropout: The dropout rate for the factorized attention.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        dropout: Optional[float] = 0.1,
    ):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        self.q_FacTs = nn.Linear(rank, rank, bias=False)
        self.v_FacTs = nn.Linear(rank, rank, bias=False)

        self.dropout = dropout
        if self.dropout is not None:
            self.dp_q = nn.Dropout(self.dropout)
            self.dp_v = nn.Dropout(self.dropout)

        self.FacTu = nn.Linear(self.dim, rank, bias=False)
        self.FacTv = nn.Linear(rank, self.dim, bias=False)

        block.attn.qkv = self

    def forward(self, x):
        qkv = self.qkv_proj(x)

        new_q = self.q_FacTs(self.FacTu(x))
        new_v = self.v_FacTs(self.FacTu(x))

        if self.dropout is not None:
            new_q = self.dp_q(new_q)
            new_v = self.dp_v(new_v)

        new_q = self.FacTv(new_q)
        new_v = self.FacTv(new_v)

        # NOTE : Scaling Factor is set to 1 as it can be tuned via the learning rate.
        qkv = torch.cat(
            [
                qkv[:, :, :, :self.dim] + new_q,  # replacing new q values
                qkv[:, :, :, self.dim:-self.dim],  # leaving the middle part as identical
                qkv[:, :, :, -self.dim:] + new_v  # replacing new v values
            ], dim=-1
        )

        return qkv


class ScaleShiftLayer(nn.Module):
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer
        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
        layer = self

    def forward(self, x):
        x = self.layer(x)
        assert self.scale.shape == self.shift.shape
        if x.shape[-1] == self.scale.shape[0]:
            return x * self.scale + self.shift
        elif x.shape[1] == self.scale.shape[0]:
            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
        else:
            raise ValueError('Input tensors do not match the shape of the scale factors.')


class SSFSurgery(nn.Module):
    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.

    Args:
        rank: This parameter is not used in `SSFSurgery`. This is kept here for consistency.
        block: The chosen attention blocks for implementing ssf.
        dim: The input dimensions determining the shape of scale and shift parameters.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.block = block

        # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])

        # If we get the embedding block, add one ScaleShiftLayer
        elif hasattr(block, "patch_embed"):
            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)

    def forward(self, x):
        return x


class SelectiveSurgery(nn.Module):
    """Base class for selectively allowing gradient updates for certain parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def allow_gradient_update_for_parameters(
        self,
        prefix: Optional[List[str]] = None,
        suffix: Optional[List[str]] = None,
        infix: Optional[List[str]] = None,
    ):
        """This function decides the parameter attributes to match for allowing gradient updates.

        Args:
            prefix: Matches the part of parameter name in front.
            suffix: Matches the part of parameter name at the end.
            infix: Matches parts of parameter name occuring in between.
        """
        for k, v in self.block.named_parameters():
            if prefix is not None and k.startswith(tuple(prefix)):
                v.requires_grad = True

            if suffix is not None and k.endswith(tuple(suffix)):
                v.requires_grad = True

            if infix is not None:
                for per_infix in infix:
                    if k.find(per_infix) != -1:
                        v.requires_grad = True

    def forward(self, x):
        return x


class AdaptFormer(nn.Module):
    """Adds AdaptFormer Module in place of the MLP Layers

    Args:
        rank: The rank is not used in this class but kept here for consistency.
        block: The chosen encoder block for implementing AdaptFormer.
        alpha: A parameters that scales the Adapter path. Can be either learnable or some fixed value.
        dropout: The dropout rate for the dropout layer between down and up projection layer.
        projection_size: The size of the projection layer.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
        dropout: Optional[float] = None,  # Does not have an obvious advantage.
        projection_size: int = 64,  # Stable choice from our preliminary exp.
    ):
        super().__init__()

        self.mlp_proj = block.mlp
        self.n_embd = block.mlp.lin1.in_features

        if alpha == 'learnable_scalar':
            self.alpha = nn.Parameter(torch.ones(1))
        else:
            self.alpha = alpha

        self.projection_size = projection_size
        self.dropout = dropout

        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.projection_size, self.n_embd)

        block.mlp = self

        if self.dropout is not None:
            self.dropout_layer = nn.Dropout(self.dropout)

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        nn.init.zeros_(self.down_proj.bias)
        nn.init.zeros_(self.up_proj.bias)

    def forward(self, x):
        residual = x
        mlp_output = self.mlp_proj(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)

        if self.dropout is not None:
            down = self.dropout_layer(down)

        up = self.up_proj(down)
        up = up * self.alpha
        output = up + residual + mlp_output

        return output


class AttentionSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for parameters in attention layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the attention layers in the image encoder.
        self.allow_gradient_update_for_parameters(prefix=["attn"])


class BiasSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for bias parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the bias parameters in the image encoder.
        self.allow_gradient_update_for_parameters(suffix=["bias"])


class LayerNormSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates in normalization layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the LayerNorm parameters in the image encoder.
        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])


class PEFT_Sam(nn.Module):
    """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.

    Args:
        model: The Segment Anything model.
        rank: The rank for low-rank adaptation.
        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
        attention_layers_to_update: Which specific layers we apply PEFT methods to.
        quantize: Whether to quantize the model for lower precision training.
    """

    def __init__(
        self,
        model: Sam,
        rank: Optional[int] = None,
        peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Union[List[int]] = None,
        quantize: bool = False,
        **module_kwargs
    ):
        super().__init__()

        if issubclass(peft_module, Union[LoRASurgery, FacTSurgery]) and (not rank or rank <= 0):
            raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")

        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
            "Invalid PEFT module"
        )

        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
        else:  # Applies PEFT to the image encoder by default
            self.peft_layers = list(range(len(model.image_encoder.blocks)))

        self.peft_module = peft_module
        self.peft_blocks = []

        # Whether to quantize the linear layers to 4 bit precision.
        # NOTE: This is currently supported for CUDA-supported devices only.
        if quantize:
            if not _have_bnb:
                raise ModuleNotFoundError("Please install 'bitsandbytes'.")

            for name, module in model.image_encoder.named_modules():
                if isinstance(module, torch.nn.Linear):
                    *parent_path, layer_name = name.split(".")
                    parent_module = model.image_encoder

                    for sub_module in parent_path:
                        parent_module = getattr(parent_module, sub_module)

                    # Create the new Linear4bit layer
                    linear_q = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        bias=False if module.bias is None else True,
                    )
                    # Assign weights and bias to the new layer
                    new_weight = bnb.nn.Params4bit(
                        data=module.weight,
                        requires_grad=False,
                    )
                    linear_q.weight = new_weight
                    if module.bias is not None:
                        linear_q.bias = torch.nn.Parameter(module.bias)

                    # Replace the original linear layer with the quantized one
                    setattr(parent_module, layer_name, linear_q)

        # Let's freeze all the pretrained image encoder layers first
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        # Add scale and shift parameters to the patch embedding layers.
        if issubclass(self.peft_module, SSFSurgery):
            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
            # If we only want specific layers with PEFT instead of all
            if t_layer_i not in self.peft_layers:
                continue

            if issubclass(self.peft_module, SelectiveSurgery):
                self.peft_blocks.append(self.peft_module(block=blk))
            else:
                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

        self.peft_blocks = nn.ModuleList(self.peft_blocks)

        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)
````
class LoRASurgery(nn.Module):
Operates on the attention layers for performing low-rank adaptation.
(Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)
In SAM, it is implemented as:
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```
Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention blocks for implementing LoRA.
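A minimal usage sketch on a hypothetical stand-in block (only the `block.attn.qkv` attribute accessed by `LoRASurgery` is provided). Because the B-matrices are zero-initialized, the wrapped projection is initially identical to the original one:

```python
import torch
import torch.nn as nn

from micro_sam.models.peft_sam import LoRASurgery


class DummyBlock(nn.Module):
    """Hypothetical stand-in for a SAM attention block; only `attn.qkv` is needed here."""
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Module()
        self.attn.qkv = nn.Linear(dim, dim * 3)


dim, rank = 64, 4
block = DummyBlock(dim)
original_qkv = block.attn.qkv               # keep a handle to the original projection
lora = LoRASurgery(rank=rank, block=block)  # replaces block.attn.qkv in place

x = torch.randn(1, 14, 14, dim)             # (B, H, W, C) tokens, as in SAM's ViT blocks
# The B-matrices are zero-initialized, so the LoRA path contributes nothing yet.
assert torch.allclose(block.attn.qkv(x), original_qkv(x))
```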
class FacTSurgery(nn.Module):
Operates on the attention layers for performing factorized attention.
(Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)
Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention blocks for implementing FacT.
- dropout: The dropout rate for the factorized attention.
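Unlike LoRA, the FacTu and FacTv projections are shared between the q and v updates, so each wrapped block adds 2·dim·rank + 2·rank² new parameters instead of 4·dim·rank. A minimal sketch on a hypothetical stand-in block (only block.attn.qkv is required):

```python
import torch.nn as nn

from micro_sam.models.peft_sam import FacTSurgery


class DummyBlock(nn.Module):
    """Hypothetical stand-in for a SAM attention block; only `attn.qkv` is needed here."""
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Module()
        self.attn.qkv = nn.Linear(dim, dim * 3)


dim, rank = 64, 4
block = DummyBlock(dim)
fact = FacTSurgery(rank=rank, block=block, dropout=None)  # replaces block.attn.qkv in place

# Count only the newly added parameters (the original projection is kept as `qkv_proj`).
n_new = sum(p.numel() for name, p in fact.named_parameters() if not name.startswith("qkv_proj"))
print(n_new)  # 2 * dim * rank + 2 * rank**2 = 544
```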
class ScaleShiftLayer(nn.Module):

Wraps a layer with learnable per-feature scale (initialized around 1) and shift (initialized around 0) parameters; used by SSFSurgery to modulate the output of the wrapped layer.
class SSFSurgery(nn.Module):
Operates on all layers in the transformer block for adding learnable scale and shift parameters.
Arguments:
- rank: This parameter is not used in SSFSurgery. It is kept here for consistency.
- block: The chosen attention blocks for implementing SSF.
- dim: The input dimensions determining the shape of scale and shift parameters.
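A minimal sketch of selecting SSF through the PEFT_Sam wrapper documented below, assuming a local SAM ViT-B checkpoint (the path is an example); since SSFSurgery ignores rank, it can be left at its default:

```python
from segment_anything import sam_model_registry

from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery

# Load a base SAM model (the checkpoint path is an example).
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# SSF does not use a rank, so the default (None) is fine.
peft_sam = PEFT_Sam(model=sam, peft_module=SSFSurgery)

# Each wrapped layer now carries per-feature scale (init ~1) and shift (init ~0) parameters.
n_ssf = sum(
    p.numel() for name, p in peft_sam.named_parameters()
    if name.endswith(".scale") or name.endswith(".shift")
)
print(n_ssf)
```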
class SelectiveSurgery(nn.Module):
Base class for selectively allowing gradient updates for certain parameters.
def allow_gradient_update_for_parameters(self, prefix: Optional[List[str]] = None, suffix: Optional[List[str]] = None, infix: Optional[List[str]] = None)

Decides which parts of the parameter names to match when enabling gradient updates.

Arguments:
- prefix: Matches the beginning of the parameter name.
- suffix: Matches the end of the parameter name.
- infix: Matches parts of the parameter name occurring in between.
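A small sketch of the matching rules on a hypothetical two-layer block: a parameter is unfrozen when its qualified name starts with one of the prefixes, ends with one of the suffixes, or contains one of the infixes. The child classes below (AttentionSurgery, BiasSurgery, LayerNormSurgery) are thin wrappers around exactly such calls:

```python
import torch.nn as nn

from micro_sam.models.peft_sam import SelectiveSurgery


class DummyBlock(nn.Module):
    """Hypothetical stand-in block with an attention layer and a norm layer."""
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(8, 8)
        self.norm1 = nn.LayerNorm(8)


block = DummyBlock()
for param in block.parameters():  # start from a fully frozen block
    param.requires_grad = False

surgery = SelectiveSurgery(block=block)
# Unfreeze everything under 'attn' plus all bias terms (what AttentionSurgery and BiasSurgery do).
surgery.allow_gradient_update_for_parameters(prefix=["attn"], suffix=["bias"])

print([name for name, p in block.named_parameters() if p.requires_grad])
# ['attn.weight', 'attn.bias', 'norm1.bias']
```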
class AdaptFormer(nn.Module):
Adds the AdaptFormer module in place of the MLP layers.
Arguments:
- rank: The rank is not used in this class but kept here for consistency.
- block: The chosen encoder block for implementing AdaptFormer.
- alpha: A parameter that scales the adapter path. Can be either learnable or a fixed value.
- dropout: The dropout rate for the dropout layer between the down- and up-projection layers.
- projection_size: The size of the projection layer.
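A minimal sketch on a hypothetical stand-in block (only block.mlp with a lin1 layer is required). The adapter runs in parallel to the original MLP; because up_proj is zero-initialized, at initialization it contributes nothing beyond the extra residual term:

```python
import torch
import torch.nn as nn

from micro_sam.models.peft_sam import AdaptFormer


class DummyMLP(nn.Module):
    """Hypothetical stand-in for a SAM MLP block with `lin1` / `lin2` layers."""
    def __init__(self, dim, hidden):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden)
        self.lin2 = nn.Linear(hidden, dim)
        self.act = nn.GELU()

    def forward(self, x):
        return self.lin2(self.act(self.lin1(x)))


class DummyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mlp = DummyMLP(dim, 4 * dim)


dim = 64
block = DummyBlock(dim)
# `rank` is unused by AdaptFormer, so it can be passed as None.
adapter = AdaptFormer(rank=None, block=block, projection_size=16)

x = torch.randn(2, 14, 14, dim)
print(block.mlp(x).shape)  # torch.Size([2, 14, 14, 64]) -- original MLP plus the adapter path
```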
class AttentionSurgery(SelectiveSurgery):
Child class for allowing gradient updates for parameters in attention layers.
class BiasSurgery(SelectiveSurgery):
Child class for allowing gradient updates for bias parameters.
class LayerNormSurgery(SelectiveSurgery):
Child class for allowing gradient updates in normalization layers.
class PEFT_Sam(nn.Module):
Wraps the Segment Anything model's image encoder with different parameter-efficient finetuning (PEFT) methods.
Arguments:
- model: The Segment Anything model.
- rank: The rank for low-rank adaptation.
- peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
- attention_layers_to_update: The specific image encoder blocks to which the PEFT method is applied.
- quantize: Whether to quantize the model for lower precision training.
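A minimal usage sketch, assuming a local SAM ViT-B checkpoint obtained via segment_anything's sam_model_registry (the checkpoint path is an example):

```python
from segment_anything import sam_model_registry

from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

# Load a base SAM model (the checkpoint path is an example).
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Apply LoRA with rank 4 to every image encoder block. The pretrained encoder weights are frozen;
# the new LoRA matrices (and the non-encoder parts of SAM) remain trainable.
# Other options: peft_module=FacTSurgery / SSFSurgery / AdaptFormer / BiasSurgery, ...,
# attention_layers_to_update=[10, 11] to restrict PEFT to specific blocks,
# or quantize=True to load the frozen encoder linears in 4-bit precision (requires bitsandbytes).
peft_sam = PEFT_Sam(model=sam, rank=4, peft_module=LoRASurgery)

n_trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
print(f"{n_trainable} trainable parameters")

# The wrapped model keeps SAM's forward signature:
# outputs = peft_sam(batched_input, multimask_output=True)
```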