micro_sam.models.peft_sam

import math
from typing import List, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam

try:
    import bitsandbytes as bnb
    _have_bnb = True
except ImportError:
    _have_bnb = False


class LoRASurgery(nn.Module):
    """Operates on the linear layers (attention and/or feed forward) for performing low-rank adaptation.

    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)

    In SAM, it is implemented as:
    ```python
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    ```

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention block for implementing LoRA.
        update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v", "mlp".
    """
    def __init__(self, rank: int, block: nn.Module, update_matrices: List[str] = ["q", "v"]):
        super().__init__()
        # Check whether all values for "update_matrices" are as expected.
        if set(update_matrices) - set(["q", "k", "v", "mlp"]):
            raise ValueError(f"Invalid key(s) in '{update_matrices}'. The supported keys are 'q', 'k', 'v' and 'mlp'.")

        self.block = block
        block.attn.qkv = AttentionLoRA(rank=rank, block=block.attn.qkv, update_matrices=update_matrices)

        if "mlp" in update_matrices:
            block.mlp = MLPLoRA(rank=rank, mlp_layer=block.mlp)

    def forward(self, x):
        return x

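`LoRASurgery` rewires the given block in place when it is constructed: `block.attn.qkv` becomes an `AttentionLoRA` wrapper and, if requested, `block.mlp` becomes an `MLPLoRA` wrapper. A minimal usage sketch, assuming a locally available SAM `vit_b` checkpoint (the path below is illustrative, not shipped with this module):

```python
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")  # hypothetical local path
block = sam.image_encoder.blocks[0]

# Replaces block.attn.qkv with AttentionLoRA and block.mlp with MLPLoRA in place.
LoRASurgery(rank=4, block=block, update_matrices=["q", "v", "mlp"])

print(type(block.attn.qkv).__name__)  # AttentionLoRA
print(type(block.mlp).__name__)       # MLPLoRA
```
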
class AttentionLoRA(nn.Module):
    """Operates on the attention layers only for performing low-rank adaptation.

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention block for implementing LoRA.
        update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v".
    """

    def __init__(self, rank: int, block: nn.Module, update_matrices: List[str] = ["q", "v"]):
        super().__init__()
        self.qkv_proj = block
        self.dim = self.qkv_proj.in_features
        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
        self.rank = rank

        # By default, we follow LoRA's recommended setup, i.e. update the "q" and "v" matrices.
        if "q" in update_matrices:
            self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
            self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)

        if "v" in update_matrices:
            self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
            self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)

        if "k" in update_matrices:
            self.w_a_linear_k = nn.Linear(self.dim, self.rank, bias=False)
            self.w_b_linear_k = nn.Linear(self.rank, self.dim, bias=False)

        self.reset_parameters()

        block = self

    def reset_parameters(self):
        if hasattr(self, "w_a_linear_q"):
            nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
            nn.init.zeros_(self.w_b_linear_q.weight)

        if hasattr(self, "w_a_linear_v"):
            nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
            nn.init.zeros_(self.w_b_linear_v.weight)

        if hasattr(self, "w_a_linear_k"):
            nn.init.kaiming_uniform_(self.w_a_linear_k.weight, a=math.sqrt(5))
            nn.init.zeros_(self.w_b_linear_k.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C

        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x)) if hasattr(self, "w_a_linear_q") else 0
        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x)) if hasattr(self, "w_a_linear_v") else 0
        new_k = self.alpha * self.w_b_linear_k(self.w_a_linear_k(x)) if hasattr(self, "w_a_linear_k") else 0
        qkv = torch.cat(
            [
                qkv[:, :, :, :self.dim] + new_q,  # add the low-rank update to the q slice.
                qkv[:, :, :, self.dim:-self.dim] + new_k,  # add the low-rank update to the k slice.
                qkv[:, :, :, -self.dim:] + new_v  # add the low-rank update to the v slice.
            ], dim=-1
        )

        return qkv

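Because `reset_parameters` zero-initializes the `w_b_linear_*` matrices, the wrapped projection reproduces the original `qkv` output exactly at initialization and only diverges once the low-rank weights are trained. A self-contained sketch, with a toy linear layer standing in for SAM's fused qkv projection (not SAM's actual attention module):

```python
import torch
import torch.nn as nn

dim = 8
qkv = nn.Linear(dim, dim * 3)                # stand-in for SAM's fused qkv projection
lora_qkv = AttentionLoRA(rank=2, block=qkv)  # wraps it; w_b_* start at zero

x = torch.randn(1, 4, 4, dim)                # (B, H, W, C), as seen by SAM's attention blocks
assert torch.allclose(lora_qkv(x), qkv(x))   # identical until the LoRA weights are updated
```
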
class MLPLoRA(nn.Module):
    """Operates on the feed forward layers for performing low-rank adaptation.

    Args:
        rank: The rank of the decomposition matrices for updating weights in the MLP layer.
        mlp_layer: The chosen MLP layer for implementing LoRA.
    """

    def __init__(self, rank: int, mlp_layer: nn.Module):
        super().__init__()

        self.mlp_layer = mlp_layer
        self.rank = rank
        self.w_a_linear_1 = nn.Linear(mlp_layer.lin1.in_features, rank, bias=False)
        self.w_b_linear_1 = nn.Linear(rank, mlp_layer.lin1.out_features, bias=False)
        self.w_a_linear_2 = nn.Linear(mlp_layer.lin2.in_features, rank, bias=False)
        self.w_b_linear_2 = nn.Linear(rank, mlp_layer.lin2.out_features, bias=False)
        self.activation = mlp_layer.act

        self.reset_parameters()

        mlp_layer = self

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.w_a_linear_1.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w_a_linear_2.weight, a=math.sqrt(5))
        nn.init.zeros_(self.w_b_linear_1.weight)
        nn.init.zeros_(self.w_b_linear_2.weight)

    def forward(self, x):
        x = self.mlp_layer.lin1(x) + self.w_b_linear_1(self.w_a_linear_1(x))
        x = self.activation(x)
        x = self.mlp_layer.lin2(x) + self.w_b_linear_2(self.w_a_linear_2(x))
        return x

class FacTSurgery(nn.Module):
    """Operates on the attention layers for performing factorized attention.

    (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention block for implementing FacT.
        dropout: The dropout rate for the factorized attention.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        dropout: Optional[float] = 0.1,
    ):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        self.q_FacTs = nn.Linear(rank, rank, bias=False)
        self.v_FacTs = nn.Linear(rank, rank, bias=False)

        self.dropout = dropout
        if self.dropout is not None:
            self.dp_q = nn.Dropout(self.dropout)
            self.dp_v = nn.Dropout(self.dropout)

        self.FacTu = nn.Linear(self.dim, rank, bias=False)
        self.FacTv = nn.Linear(rank, self.dim, bias=False)

        block.attn.qkv = self

    def forward(self, x):
        qkv = self.qkv_proj(x)

        new_q = self.q_FacTs(self.FacTu(x))
        new_v = self.v_FacTs(self.FacTu(x))

        if self.dropout is not None:
            new_q = self.dp_q(new_q)
            new_v = self.dp_v(new_v)

        new_q = self.FacTv(new_q)
        new_v = self.FacTv(new_v)

        # NOTE: The scaling factor is set to 1 as it can be tuned via the learning rate.
        qkv = torch.cat(
            [
                qkv[:, :, :, :self.dim] + new_q,  # add the low-rank update to the q slice
                qkv[:, :, :, self.dim:-self.dim],  # keep the k slice unchanged
                qkv[:, :, :, -self.dim:] + new_v  # add the low-rank update to the v slice
            ], dim=-1
        )

        return qkv

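Unlike `AttentionLoRA`, which learns a separate A/B pair for every updated matrix, `FacTSurgery` shares the `FacTu`/`FacTv` projections between the q and v updates and learns only small rank-by-rank factors per matrix. For `vit_b` (embedding dimension 768) at rank 4 this adds 2·768·4 + 2·4·4 = 6,176 parameters per block, roughly half of the 12,288 that LoRA on q and v adds at the same rank. A hedged sketch that only counts parameters, using a minimal stand-in for an encoder block rather than SAM itself:

```python
from types import SimpleNamespace

import torch.nn as nn

dim, rank = 768, 4                                         # vit_b embedding dim, example rank
block = SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3)))

fact = FacTSurgery(rank=rank, block=block)                 # replaces block.attn.qkv in place
added = sum(p.numel() for n, p in fact.named_parameters() if not n.startswith("qkv_proj"))
print(added)  # 6176 = 2 * dim * rank + 2 * rank * rank
```
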
class ScaleShiftLayer(nn.Module):
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer
        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
        layer = self

    def forward(self, x):
        x = self.layer(x)
        assert self.scale.shape == self.shift.shape
        if x.shape[-1] == self.scale.shape[0]:
            return x * self.scale + self.shift
        elif x.shape[1] == self.scale.shape[0]:
            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
        else:
            raise ValueError('Input tensors do not match the shape of the scale factors.')


class SSFSurgery(nn.Module):
    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.

    Args:
        rank: This parameter is not used in `SSFSurgery`. This is kept here for consistency.
        block: The chosen attention blocks for implementing ssf.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.block = block

        # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])

        # If we get the embedding block, add one ScaleShiftLayer
        elif hasattr(block, "patch_embed"):
            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)

    def forward(self, x):
        return x

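`ScaleShiftLayer` can be exercised on its own to see exactly what SSF trains: only the channel-wise `scale` and `shift` parameters applied to the wrapped layer's output. A standalone sketch with a toy linear layer (not an actual SAM sub-module):

```python
import torch
import torch.nn as nn

lin = nn.Linear(16, 32)
ssf = ScaleShiftLayer(layer=lin, dim=32)  # scale ~ N(1, 0.2), shift ~ N(0, 0.2)

x = torch.randn(4, 16)
assert torch.allclose(ssf(x), lin(x) * ssf.scale + ssf.shift)
```
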
class SelectiveSurgery(nn.Module):
    """Base class for selectively allowing gradient updates for certain parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def allow_gradient_update_for_parameters(
        self,
        prefix: Optional[List[str]] = None,
        suffix: Optional[List[str]] = None,
        infix: Optional[List[str]] = None,
    ):
        """This function decides the parameter attributes to match for allowing gradient updates.

        Args:
            prefix: Matches the part of parameter name in front.
            suffix: Matches the part of parameter name at the end.
            infix: Matches parts of parameter name occurring in between.
        """
        for k, v in self.block.named_parameters():
            if prefix is not None and k.startswith(tuple(prefix)):
                v.requires_grad = True

            if suffix is not None and k.endswith(tuple(suffix)):
                v.requires_grad = True

            if infix is not None:
                for per_infix in infix:
                    if k.find(per_infix) != -1:
                        v.requires_grad = True

    def forward(self, x):
        return x

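The matching is purely name-based over `named_parameters()`, so the mechanism can be demonstrated on any frozen module. A toy sketch (not a SAM encoder block) in which only bias parameters are re-enabled, which is what `BiasSurgery` further below does for an encoder block:

```python
import torch.nn as nn

toy_block = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
for p in toy_block.parameters():
    p.requires_grad = False

surgery = SelectiveSurgery(block=toy_block)
surgery.allow_gradient_update_for_parameters(suffix=["bias"])

print([n for n, p in toy_block.named_parameters() if p.requires_grad])  # ['0.bias', '1.bias']
```
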
class AdaptFormer(nn.Module):
    """Adds the AdaptFormer module in place of the MLP layers.

    Args:
        rank: The rank is not used in this class but kept here for consistency.
        block: The chosen encoder block for implementing AdaptFormer.
        alpha: A parameter that scales the adapter path. Can be either learnable or some fixed value.
        dropout: The dropout rate for the dropout layer between the down and up projection layers.
        projection_size: The size of the projection layer.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
        dropout: Optional[float] = None,  # Does not have an obvious advantage.
        projection_size: int = 64,  # Stable choice from our preliminary exp.
    ):
        super().__init__()

        self.mlp_proj = block.mlp
        self.n_embd = block.mlp.lin1.in_features

        if alpha == 'learnable_scalar':
            self.alpha = nn.Parameter(torch.ones(1))
        else:
            self.alpha = alpha

        self.projection_size = projection_size
        self.dropout = dropout

        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.projection_size, self.n_embd)

        block.mlp = self

        if self.dropout is not None:
            self.dropout_layer = nn.Dropout(self.dropout)

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        nn.init.zeros_(self.down_proj.bias)
        nn.init.zeros_(self.up_proj.bias)

    def forward(self, x):
        residual = x
        mlp_output = self.mlp_proj(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)

        if self.dropout is not None:
            down = self.dropout_layer(down)

        up = self.up_proj(down)
        up = up * self.alpha
        output = up + residual + mlp_output

        return output

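`AdaptFormer` runs a bottleneck adapter (down-projection, ReLU, up-projection) in parallel to the original MLP and adds its scaled output to the residual and MLP paths. Since `up_proj` is zero-initialized, the adapter path contributes nothing at initialization. A self-contained sketch with a hypothetical `TinyMLP` standing in for SAM's `MLPBlock` (not part of SAM or this module):

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class TinyMLP(nn.Module):
    """Minimal stand-in for SAM's MLPBlock: lin1 -> act -> lin2."""
    def __init__(self, dim=16, hidden=32):
        super().__init__()
        self.lin1, self.lin2, self.act = nn.Linear(dim, hidden), nn.Linear(hidden, dim), nn.GELU()

    def forward(self, x):
        return self.lin2(self.act(self.lin1(x)))


block = SimpleNamespace(mlp=TinyMLP())
adapter = AdaptFormer(rank=1, block=block, projection_size=8)  # rank is unused by AdaptFormer

x = torch.randn(2, 16)
# up_proj is zero-initialized, so the adapter path is initially inactive:
assert torch.allclose(adapter(x), x + adapter.mlp_proj(x))
```
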
class AttentionSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for parameters in attention layers."""

    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the attention layers in the image encoder.
        self.allow_gradient_update_for_parameters(prefix=["attn"])


class BiasSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for bias parameters."""

    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the bias parameters in the image encoder.
        self.allow_gradient_update_for_parameters(suffix=["bias"])


class LayerNormSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates in normalization layers."""

    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the LayerNorm parameters in the image encoder.
        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])


class ClassicalSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for all parameters of specific blocks."""

    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        self.block = block

        for k, v in self.block.named_parameters():
            v.requires_grad = True

    def forward(self, x):
        return x

class PEFT_Sam(nn.Module):
    """Wraps the Segment Anything model's image encoder with different parameter efficient finetuning methods.

    Args:
        model: The Segment Anything model.
        rank: The rank for low-rank adaptation.
        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
        attention_layers_to_update: Which specific layers we apply PEFT methods to.
            For reference, the total number of blocks for 'vit_b' is 12, for 'vit_l' is 24 and for 'vit_h' is 32.
        quantize: Whether to quantize the model for lower precision training.
        module_kwargs: The additional arguments for the respective PEFT modules.
    """

    def __init__(
        self,
        model: Sam,
        rank: Optional[int] = None,
        peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Optional[List[int]] = None,
        quantize: bool = False,
        **module_kwargs
    ):
        super().__init__()

        if issubclass(peft_module, Union[LoRASurgery, FacTSurgery]) and (not rank or rank <= 0):
            raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")

        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
            "Invalid PEFT module"
        )
        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
        else:  # Applies PEFT to all blocks of the image encoder by default.
            self.peft_layers = list(range(len(model.image_encoder.blocks)))

        self.peft_module = peft_module
        self.peft_blocks = []

        # Whether to quantize the linear layers to 4 bit precision.
        # NOTE: This is currently supported for CUDA-supported devices only.
        if quantize:
            if not _have_bnb:
                raise ModuleNotFoundError("Please install 'bitsandbytes'.")

            for name, module in model.image_encoder.named_modules():
                if isinstance(module, torch.nn.Linear):
                    *parent_path, layer_name = name.split(".")
                    parent_module = model.image_encoder

                    for sub_module in parent_path:
                        parent_module = getattr(parent_module, sub_module)

                    # Create the new Linear4bit layer.
                    linear_q = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        bias=False if module.bias is None else True,
                    )
                    # Assign weights and bias to the new layer.
                    new_weight = bnb.nn.Params4bit(
                        data=module.weight,
                        requires_grad=False,
                    )
                    linear_q.weight = new_weight
                    if module.bias is not None:
                        linear_q.bias = torch.nn.Parameter(module.bias)

                    # Replace the original linear layer with the quantized one.
                    setattr(parent_module, layer_name, linear_q)

        # Let's freeze all the pretrained image encoder layers first.
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        # Add scale and shift parameters to the patch embedding layers.
        if issubclass(self.peft_module, SSFSurgery):
            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))

        # If specified, the attention layers to update should match the available blocks.
        if attention_layers_to_update and (
            set(attention_layers_to_update) - set(list(range(len(model.image_encoder.blocks))))
        ):
            raise ValueError("The chosen layer(s) to apply the PEFT method to are not valid transformer block ids.")

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):

            # If we only want specific layers with PEFT instead of all.
            if t_layer_i not in self.peft_layers:
                continue

            if issubclass(self.peft_module, SelectiveSurgery):
                self.peft_blocks.append(self.peft_module(block=blk))
            else:
                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

        self.peft_blocks = nn.ModuleList(self.peft_blocks)
        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)
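
A hedged end-to-end sketch of wrapping a full SAM model (the checkpoint path is illustrative): the image encoder is frozen, the chosen PEFT blocks are attached in place, and only the newly added PEFT parameters plus the parts of SAM outside the image encoder (prompt encoder and mask decoder) remain trainable.

```python
from segment_anything import sam_model_registry

# Illustrative local checkpoint path; substitute your own SAM weights.
sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")

peft_sam = PEFT_Sam(model=sam, rank=4, peft_module=LoRASurgery, update_matrices=["q", "v"])

trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
total = sum(p.numel() for p in peft_sam.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,}")
```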
class LoRASurgery(torch.nn.modules.module.Module):
17class LoRASurgery(nn.Module):
18    """Operates on the linear layers (attention and/or other feed forward) for performing low-rank adaptation.
19
20    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)
21
22    In SAM, it is implemented as:
23    ```python
24    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
25    B, N, C = x.shape
26    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
27    q, k, v = qkv.unbind(0)
28    ```
29
30    Args:
31        rank: The rank of the decomposition matrices for updating weights in each attention layer.
32        block: The chosen attention blocks for implementing LoRA.
33        update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v", "mlp".
34    """
35    def __init__(self, rank: int, block: nn.Module, update_matrices: List[str] = ["q", "v"]):
36        super().__init__()
37        # Check whether all values for "update_matrices" are as expected.
38        if set(update_matrices) - set(["q", "k", "v", "mlp"]):
39            raise ValueError(f"Some of the expected keys for updating matrics in '{update_matrices}' are not expected.")
40
41        self.block = block
42        block.attn.qkv = AttentionLoRA(rank=rank, block=block.attn.qkv, update_matrices=update_matrices)
43
44        if "mlp" in update_matrices:
45            block.mlp = MLPLoRA(rank=rank, mlp_layer=block.mlp)
46
47    def forward(self, x):
48        return x

Operates on the linear layers (attention and/or other feed forward) for performing low-rank adaptation.

(Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)

In SAM, it is implemented as:

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
Arguments:
  • rank: The rank of the decomposition matrices for updating weights in each attention layer.
  • block: The chosen attention blocks for implementing LoRA.
  • update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v", "mlp".
LoRASurgery( rank: int, block: torch.nn.modules.module.Module, update_matrices: List[str] = ['q', 'v'])
35    def __init__(self, rank: int, block: nn.Module, update_matrices: List[str] = ["q", "v"]):
36        super().__init__()
37        # Check whether all values for "update_matrices" are as expected.
38        if set(update_matrices) - set(["q", "k", "v", "mlp"]):
39            raise ValueError(f"Some of the expected keys for updating matrics in '{update_matrices}' are not expected.")
40
41        self.block = block
42        block.attn.qkv = AttentionLoRA(rank=rank, block=block.attn.qkv, update_matrices=update_matrices)
43
44        if "mlp" in update_matrices:
45            block.mlp = MLPLoRA(rank=rank, mlp_layer=block.mlp)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

block
def forward(self, x):
47    def forward(self, x):
48        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class AttentionLoRA(torch.nn.modules.module.Module):
 51class AttentionLoRA(nn.Module):
 52    """Operates on the attention layers only for performing low-rank adaptation.
 53
 54    Args:
 55        rank: The rank of the decomposition matrices for updating weights in each attention layer.
 56        block: The chosen attention blocks for implementing LoRA.
 57        update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v".
 58    """
 59
 60    def __init__(self, rank: int, block: nn.Module, update_matrices: List[str] = ["q", "v"]):
 61        super().__init__()
 62        self.qkv_proj = block
 63        self.dim = self.qkv_proj.in_features
 64        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
 65        self.rank = rank
 66
 67        # By default, we follow LoRA's recommended setup, i.e. update the "q" and "v" matrices.
 68        if "q" in update_matrices:
 69            self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
 70            self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
 71
 72        if "v" in update_matrices:
 73            self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
 74            self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
 75
 76        if "k" in update_matrices:
 77            self.w_a_linear_k = nn.Linear(self.dim, self.rank, bias=False)
 78            self.w_b_linear_k = nn.Linear(self.rank, self.dim, bias=False)
 79
 80        self.reset_parameters()
 81
 82        block = self
 83
 84    def reset_parameters(self):
 85        if hasattr(self, "w_a_linear_q"):
 86            nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
 87            nn.init.zeros_(self.w_b_linear_q.weight)
 88
 89        if hasattr(self, "w_a_linear_v"):
 90            nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
 91            nn.init.zeros_(self.w_b_linear_v.weight)
 92
 93        if hasattr(self, "w_a_linear_k"):
 94            nn.init.kaiming_uniform_(self.w_a_linear_k.weight, a=math.sqrt(5))
 95            nn.init.zeros_(self.w_b_linear_k.weight)
 96
 97    def forward(self, x):
 98        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
 99
100        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x)) if hasattr(self, "w_a_linear_q") else 0
101        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x)) if hasattr(self, "w_a_linear_v") else 0
102        new_k = self.alpha * self.w_b_linear_k(self.w_a_linear_k(x)) if hasattr(self, "w_a_linear_k") else 0
103        qkv = torch.cat(
104            [
105                qkv[:, :, :, :self.dim] + new_q,  # replacing new q values.
106                qkv[:, :, :, self.dim:-self.dim] + new_k,  # replacing new k values.
107                qkv[:, :, :, -self.dim:] + new_v  # replacing new v values.
108            ], dim=-1
109        )
110
111        return qkv

Operates on the attention layers only for performing low-rank adaptation.

Arguments:
  • rank: The rank of the decomposition matrices for updating weights in each attention layer.
  • block: The chosen attention blocks for implementing LoRA.
  • update_matrices: Which specific matrices to update in the attention layer. Choice of "q", "k", "v".
AttentionLoRA( rank: int, block: torch.nn.modules.module.Module, update_matrices: List[str] = ['q', 'v'])
60    def __init__(self, rank: int, block: nn.Module, update_matrices: List[str] = ["q", "v"]):
61        super().__init__()
62        self.qkv_proj = block
63        self.dim = self.qkv_proj.in_features
64        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
65        self.rank = rank
66
67        # By default, we follow LoRA's recommended setup, i.e. update the "q" and "v" matrices.
68        if "q" in update_matrices:
69            self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
70            self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
71
72        if "v" in update_matrices:
73            self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
74            self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
75
76        if "k" in update_matrices:
77            self.w_a_linear_k = nn.Linear(self.dim, self.rank, bias=False)
78            self.w_b_linear_k = nn.Linear(self.rank, self.dim, bias=False)
79
80        self.reset_parameters()
81
82        block = self

Initialize internal Module state, shared by both nn.Module and ScriptModule.

qkv_proj
dim
alpha
rank
def reset_parameters(self):
84    def reset_parameters(self):
85        if hasattr(self, "w_a_linear_q"):
86            nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
87            nn.init.zeros_(self.w_b_linear_q.weight)
88
89        if hasattr(self, "w_a_linear_v"):
90            nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
91            nn.init.zeros_(self.w_b_linear_v.weight)
92
93        if hasattr(self, "w_a_linear_k"):
94            nn.init.kaiming_uniform_(self.w_a_linear_k.weight, a=math.sqrt(5))
95            nn.init.zeros_(self.w_b_linear_k.weight)
def forward(self, x):
 97    def forward(self, x):
 98        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
 99
100        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x)) if hasattr(self, "w_a_linear_q") else 0
101        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x)) if hasattr(self, "w_a_linear_v") else 0
102        new_k = self.alpha * self.w_b_linear_k(self.w_a_linear_k(x)) if hasattr(self, "w_a_linear_k") else 0
103        qkv = torch.cat(
104            [
105                qkv[:, :, :, :self.dim] + new_q,  # replacing new q values.
106                qkv[:, :, :, self.dim:-self.dim] + new_k,  # replacing new k values.
107                qkv[:, :, :, -self.dim:] + new_v  # replacing new v values.
108            ], dim=-1
109        )
110
111        return qkv

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class MLPLoRA(torch.nn.modules.module.Module):
114class MLPLoRA(nn.Module):
115    """Operates on the feed forward layers for performing low-rank adaptation.
116
117    Args:
118        rank: The rank of the decomposition matrices for updating weights in each attention layer.
119        mlp_layer: The chosen MLP layer for implementing LoRA.
120    """
121
122    def __init__(self, rank: int, mlp_layer: nn.Module):
123        super().__init__()
124
125        self.mlp_layer = mlp_layer
126        self.rank = rank
127        self.w_a_linear_1 = nn.Linear(mlp_layer.lin1.in_features, rank, bias=False)
128        self.w_b_linear_1 = nn.Linear(rank, mlp_layer.lin1.out_features, bias=False)
129        self.w_a_linear_2 = nn.Linear(mlp_layer.lin2.in_features, rank, bias=False)
130        self.w_b_linear_2 = nn.Linear(rank, mlp_layer.lin2.out_features, bias=False)
131        self.activation = mlp_layer.act
132
133        self.reset_parameters()
134
135        mlp_layer = self
136
137    def reset_parameters(self):
138        nn.init.kaiming_uniform_(self.w_a_linear_1.weight, a=math.sqrt(5))
139        nn.init.kaiming_uniform_(self.w_a_linear_2.weight, a=math.sqrt(5))
140        nn.init.zeros_(self.w_b_linear_1.weight)
141        nn.init.zeros_(self.w_b_linear_2.weight)
142
143    def forward(self, x):
144        x = self.mlp_layer.lin1(x) + self.w_b_linear_1(self.w_a_linear_1(x))
145        x = self.activation(x)
146        x = self.mlp_layer.lin2(x) + self.w_b_linear_2(self.w_a_linear_2(x))
147        return x

Operates on the feed forward layers for performing low-rank adaptation.

Arguments:
  • rank: The rank of the decomposition matrices for updating weights in each attention layer.
  • mlp_layer: The chosen MLP layer for implementing LoRA.
MLPLoRA(rank: int, mlp_layer: torch.nn.modules.module.Module)
122    def __init__(self, rank: int, mlp_layer: nn.Module):
123        super().__init__()
124
125        self.mlp_layer = mlp_layer
126        self.rank = rank
127        self.w_a_linear_1 = nn.Linear(mlp_layer.lin1.in_features, rank, bias=False)
128        self.w_b_linear_1 = nn.Linear(rank, mlp_layer.lin1.out_features, bias=False)
129        self.w_a_linear_2 = nn.Linear(mlp_layer.lin2.in_features, rank, bias=False)
130        self.w_b_linear_2 = nn.Linear(rank, mlp_layer.lin2.out_features, bias=False)
131        self.activation = mlp_layer.act
132
133        self.reset_parameters()
134
135        mlp_layer = self

Initialize internal Module state, shared by both nn.Module and ScriptModule.

mlp_layer
rank
w_a_linear_1
w_b_linear_1
w_a_linear_2
w_b_linear_2
activation
def reset_parameters(self):
137    def reset_parameters(self):
138        nn.init.kaiming_uniform_(self.w_a_linear_1.weight, a=math.sqrt(5))
139        nn.init.kaiming_uniform_(self.w_a_linear_2.weight, a=math.sqrt(5))
140        nn.init.zeros_(self.w_b_linear_1.weight)
141        nn.init.zeros_(self.w_b_linear_2.weight)
def forward(self, x):
143    def forward(self, x):
144        x = self.mlp_layer.lin1(x) + self.w_b_linear_1(self.w_a_linear_1(x))
145        x = self.activation(x)
146        x = self.mlp_layer.lin2(x) + self.w_b_linear_2(self.w_a_linear_2(x))
147        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class FacTSurgery(torch.nn.modules.module.Module):
150class FacTSurgery(nn.Module):
151    """Operates on the attention layers for performing factorized attention.
152
153    (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)
154
155    Args:
156        rank: The rank of the decomposition matrices for updating weights in each attention layer.
157        block: The chosen attention blocks for implementing fact.
158        dropout: The dropout rate for the factorized attention.
159    """
160    def __init__(
161        self,
162        rank: int,
163        block: nn.Module,
164        dropout: Optional[float] = 0.1,
165    ):
166        super().__init__()
167        self.qkv_proj = block.attn.qkv
168        self.dim = self.qkv_proj.in_features
169
170        self.q_FacTs = nn.Linear(rank, rank, bias=False)
171        self.v_FacTs = nn.Linear(rank, rank, bias=False)
172
173        self.dropout = dropout
174        if self.dropout is not None:
175            self.dp_q = nn.Dropout(self.dropout)
176            self.dp_v = nn.Dropout(self.dropout)
177
178        self.FacTu = nn.Linear(self.dim, rank, bias=False)
179        self.FacTv = nn.Linear(rank, self.dim, bias=False)
180
181        block.attn.qkv = self
182
183    def forward(self, x):
184        qkv = self.qkv_proj(x)
185
186        new_q = self.q_FacTs(self.FacTu(x))
187        new_v = self.v_FacTs(self.FacTu(x))
188
189        if self.dropout is not None:
190            new_q = self.dp_q(new_q)
191            new_v = self.dp_v(new_v)
192
193        new_q = self.FacTv(new_q)
194        new_v = self.FacTv(new_v)
195
196        # NOTE : Scaling Factor is set to 1 as it can be tuned via the learning rate.
197        qkv = torch.cat(
198            [
199                qkv[:, :, :, :self.dim] + new_q,  # replacing new q values
200                qkv[:, :, :, self.dim:-self.dim],  # leaving the middle part as identical
201                qkv[:, :, :, -self.dim:] + new_v  # replacing new v values
202            ], dim=-1
203        )
204
205        return qkv

Operates on the attention layers for performing factorized attention.

(Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

Arguments:
  • rank: The rank of the decomposition matrices for updating weights in each attention layer.
  • block: The chosen attention blocks for implementing fact.
  • dropout: The dropout rate for the factorized attention.
FacTSurgery( rank: int, block: torch.nn.modules.module.Module, dropout: Optional[float] = 0.1)
160    def __init__(
161        self,
162        rank: int,
163        block: nn.Module,
164        dropout: Optional[float] = 0.1,
165    ):
166        super().__init__()
167        self.qkv_proj = block.attn.qkv
168        self.dim = self.qkv_proj.in_features
169
170        self.q_FacTs = nn.Linear(rank, rank, bias=False)
171        self.v_FacTs = nn.Linear(rank, rank, bias=False)
172
173        self.dropout = dropout
174        if self.dropout is not None:
175            self.dp_q = nn.Dropout(self.dropout)
176            self.dp_v = nn.Dropout(self.dropout)
177
178        self.FacTu = nn.Linear(self.dim, rank, bias=False)
179        self.FacTv = nn.Linear(rank, self.dim, bias=False)
180
181        block.attn.qkv = self

Initialize internal Module state, shared by both nn.Module and ScriptModule.

qkv_proj
dim
q_FacTs
v_FacTs
dropout
FacTu
FacTv
def forward(self, x):
183    def forward(self, x):
184        qkv = self.qkv_proj(x)
185
186        new_q = self.q_FacTs(self.FacTu(x))
187        new_v = self.v_FacTs(self.FacTu(x))
188
189        if self.dropout is not None:
190            new_q = self.dp_q(new_q)
191            new_v = self.dp_v(new_v)
192
193        new_q = self.FacTv(new_q)
194        new_v = self.FacTv(new_v)
195
196        # NOTE : Scaling Factor is set to 1 as it can be tuned via the learning rate.
197        qkv = torch.cat(
198            [
199                qkv[:, :, :, :self.dim] + new_q,  # replacing new q values
200                qkv[:, :, :, self.dim:-self.dim],  # leaving the middle part as identical
201                qkv[:, :, :, -self.dim:] + new_v  # replacing new v values
202            ], dim=-1
203        )
204
205        return qkv

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class ScaleShiftLayer(torch.nn.modules.module.Module):
208class ScaleShiftLayer(nn.Module):
209    def __init__(self, layer, dim):
210        super().__init__()
211        self.layer = layer
212        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
213        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
214        layer = self
215
216    def forward(self, x):
217        x = self.layer(x)
218        assert self.scale.shape == self.shift.shape
219        if x.shape[-1] == self.scale.shape[0]:
220            return x * self.scale + self.shift
221        elif x.shape[1] == self.scale.shape[0]:
222            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
223        else:
224            raise ValueError('Input tensors do not match the shape of the scale factors.')

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ScaleShiftLayer(layer, dim)
209    def __init__(self, layer, dim):
210        super().__init__()
211        self.layer = layer
212        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
213        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
214        layer = self

Initialize internal Module state, shared by both nn.Module and ScriptModule.

layer
scale
shift
def forward(self, x):
216    def forward(self, x):
217        x = self.layer(x)
218        assert self.scale.shape == self.shift.shape
219        if x.shape[-1] == self.scale.shape[0]:
220            return x * self.scale + self.shift
221        elif x.shape[1] == self.scale.shape[0]:
222            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
223        else:
224            raise ValueError('Input tensors do not match the shape of the scale factors.')

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class SSFSurgery(torch.nn.modules.module.Module):
227class SSFSurgery(nn.Module):
228    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.
229
230    Args:
231        rank: This parameter is not used in `SSFSurgery`. This is kept here for consistency.
232        block: The chosen attention blocks for implementing ssf.
233    """
234    def __init__(self, rank: int, block: nn.Module):
235        super().__init__()
236        self.block = block
237
238        # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
239        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
240            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
241            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
242            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
243            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
244            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
245            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])
246
247        # If we get the embedding block, add one ScaleShiftLayer
248        elif hasattr(block, "patch_embed"):
249            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)
250
251    def forward(self, x):
252        return x

Operates on all layers in the transformer block for adding learnable scale and shift parameters.

Arguments:
  • rank: This parameter is not used in SSFSurgery. This is kept here for consistency.
  • block: The chosen attention blocks for implementing ssf.
SSFSurgery(rank: int, block: torch.nn.modules.module.Module)
234    def __init__(self, rank: int, block: nn.Module):
235        super().__init__()
236        self.block = block
237
238        # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
239        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
240            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
241            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
242            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
243            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
244            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
245            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])
246
247        # If we get the embedding block, add one ScaleShiftLayer
248        elif hasattr(block, "patch_embed"):
249            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

block
def forward(self, x):
251    def forward(self, x):
252        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class SelectiveSurgery(torch.nn.modules.module.Module):
255class SelectiveSurgery(nn.Module):
256    """Base class for selectively allowing gradient updates for certain parameters.
257    """
258    def __init__(self, block: nn.Module):
259        super().__init__()
260        self.block = block
261
262    def allow_gradient_update_for_parameters(
263        self,
264        prefix: Optional[List[str]] = None,
265        suffix: Optional[List[str]] = None,
266        infix: Optional[List[str]] = None,
267    ):
268        """This function decides the parameter attributes to match for allowing gradient updates.
269
270        Args:
271            prefix: Matches the part of parameter name in front.
272            suffix: Matches the part of parameter name at the end.
273            infix: Matches parts of parameter name occuring in between.
274        """
275        for k, v in self.block.named_parameters():
276            if prefix is not None and k.startswith(tuple(prefix)):
277                v.requires_grad = True
278
279            if suffix is not None and k.endswith(tuple(suffix)):
280                v.requires_grad = True
281
282            if infix is not None:
283                for per_infix in infix:
284                    if k.find(per_infix) != -1:
285                        v.requires_grad = True
286
287    def forward(self, x):
288        return x

Base class for selectively allowing gradient updates for certain parameters.

SelectiveSurgery(block: torch.nn.modules.module.Module)
258    def __init__(self, block: nn.Module):
259        super().__init__()
260        self.block = block

Initialize internal Module state, shared by both nn.Module and ScriptModule.

block
def allow_gradient_update_for_parameters( self, prefix: Optional[List[str]] = None, suffix: Optional[List[str]] = None, infix: Optional[List[str]] = None):
262    def allow_gradient_update_for_parameters(
263        self,
264        prefix: Optional[List[str]] = None,
265        suffix: Optional[List[str]] = None,
266        infix: Optional[List[str]] = None,
267    ):
268        """This function decides the parameter attributes to match for allowing gradient updates.
269
270        Args:
271            prefix: Matches the part of parameter name in front.
272            suffix: Matches the part of parameter name at the end.
273            infix: Matches parts of parameter name occuring in between.
274        """
275        for k, v in self.block.named_parameters():
276            if prefix is not None and k.startswith(tuple(prefix)):
277                v.requires_grad = True
278
279            if suffix is not None and k.endswith(tuple(suffix)):
280                v.requires_grad = True
281
282            if infix is not None:
283                for per_infix in infix:
284                    if k.find(per_infix) != -1:
285                        v.requires_grad = True

This function decides the parameter attributes to match for allowing gradient updates.

Arguments:
  • prefix: Matches the part of parameter name in front.
  • suffix: Matches the part of parameter name at the end.
  • infix: Matches parts of parameter name occuring in between.
def forward(self, x):
287    def forward(self, x):
288        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class AdaptFormer(torch.nn.modules.module.Module):
291class AdaptFormer(nn.Module):
292    """Adds AdaptFormer Module in place of the MLP Layers
293
294    Args:
295        rank: The rank is not used in this class but kept here for consistency.
296        block: The chosen encoder block for implementing AdaptFormer.
297        alpha: A parameters that scales the Adapter path. Can be either learnable or some fixed value.
298        dropout: The dropout rate for the dropout layer between down and up projection layer.
299        projection_size: The size of the projection layer.
300    """
301    def __init__(
302        self,
303        rank: int,
304        block: nn.Module,
305        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
306        dropout: Optional[float] = None,  # Does not have an obvious advantage.
307        projection_size: int = 64,  # Stable choice from our preliminary exp.
308    ):
309        super().__init__()
310
311        self.mlp_proj = block.mlp
312        self.n_embd = block.mlp.lin1.in_features
313
314        if alpha == 'learnable_scalar':
315            self.alpha = nn.Parameter(torch.ones(1))
316        else:
317            self.alpha = alpha
318
319        self.projection_size = projection_size
320        self.dropout = dropout
321
322        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
323        self.non_linear_func = nn.ReLU()
324        self.up_proj = nn.Linear(self.projection_size, self.n_embd)
325
326        block.mlp = self
327
328        if self.dropout is not None:
329            self.dropout_layer = nn.Dropout(self.dropout)
330
331        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
332        nn.init.zeros_(self.up_proj.weight)
333        nn.init.zeros_(self.down_proj.bias)
334        nn.init.zeros_(self.up_proj.bias)
335
336    def forward(self, x):
337        residual = x
338        mlp_output = self.mlp_proj(x)
339
340        down = self.down_proj(x)
341        down = self.non_linear_func(down)
342
343        if self.dropout is not None:
344            down = self.dropout_layer(down)
345
346        up = self.up_proj(down)
347        up = up * self.alpha
348        output = up + residual + mlp_output
349
350        return output

Adds AdaptFormer Module in place of the MLP Layers

Arguments:
  • rank: The rank is not used in this class but kept here for consistency.
  • block: The chosen encoder block for implementing AdaptFormer.
  • alpha: A parameters that scales the Adapter path. Can be either learnable or some fixed value.
  • dropout: The dropout rate for the dropout layer between down and up projection layer.
  • projection_size: The size of the projection layer.

class AttentionSurgery(SelectiveSurgery):
353class AttentionSurgery(SelectiveSurgery):
354    """Child class for allowing gradient updates for parameters in attention layers."""
355
356    def __init__(self, block: nn.Module):
357        super().__init__(block=block)
358        # Allow gradient updates for the attention layers in the image encoder.
359        self.allow_gradient_update_for_parameters(prefix=["attn"])

Child class for allowing gradient updates for parameters in attention layers.
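
An illustrative check (not part of the library's docs) of what this surgery does: after freezing a block, applying AttentionSurgery makes exactly the parameters under the block's `attn` submodule trainable again. The model setup mirrors the AdaptFormer sketch above.

```python
from segment_anything import sam_model_registry
from micro_sam.models.peft_sam import AttentionSurgery

sam = sam_model_registry["vit_b"](checkpoint=None)
block = sam.image_encoder.blocks[0]

# PEFT_Sam freezes the whole image encoder before applying a surgery;
# we mimic that here for a single block.
for param in block.parameters():
    param.requires_grad = False

# The surgery flips `requires_grad` back on for the 'attn.*' parameters.
AttentionSurgery(block=block)

trainable = sum(p.numel() for p in block.parameters() if p.requires_grad)
total = sum(p.numel() for p in block.parameters())
print(f"trainable parameters in the block: {trainable} / {total}")
```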

class BiasSurgery(SelectiveSurgery):
362class BiasSurgery(SelectiveSurgery):
363    """Child class for allowing gradient updates for bias parameters."""
364
365    def __init__(self, block: nn.Module):
366        super().__init__(block=block)
367        # Allow gradient updates for the bias parameters in the image encoder.
368        self.allow_gradient_update_for_parameters(suffix=["bias"])

Child class for allowing gradient updates for bias parameters.
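
Updating only the bias terms is akin to BitFit-style bias tuning. The illustrative snippet below (same assumed setup as above) lists which parameters of a frozen block become trainable again; LayerNormSurgery below follows the same pattern, matching the 'norm1'/'norm2' infix instead of the 'bias' suffix.

```python
from segment_anything import sam_model_registry
from micro_sam.models.peft_sam import BiasSurgery

sam = sam_model_registry["vit_b"](checkpoint=None)
block = sam.image_encoder.blocks[0]
for param in block.parameters():
    param.requires_grad = False

BiasSurgery(block=block)

# Expect only names ending in 'bias', e.g. 'attn.qkv.bias', 'mlp.lin1.bias', ...
print([name for name, p in block.named_parameters() if p.requires_grad])
```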

class LayerNormSurgery(SelectiveSurgery):
371class LayerNormSurgery(SelectiveSurgery):
372    """Child class for allowing gradient updates in normalization layers."""
373
374    def __init__(self, block: nn.Module):
375        super().__init__(block=block)
376        # Allow gradient updates for the LayerNorm parameters in the image encoder.
377        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])

Child class for allowing gradient updates in normalization layers.

class ClassicalSurgery(SelectiveSurgery):
380class ClassicalSurgery(SelectiveSurgery):
381    """Child class for classical fine-tuning of specific blocks, i.e. all parameters of the chosen block are made trainable."""
382
383    def __init__(self, block: nn.Module):
384        super().__init__(block=block)
385        self.block = block
386
387        for k, v in self.block.named_parameters():
388            v.requires_grad = True
389
390    def forward(self, x):
391        return x

Child class for classical fine-tuning of specific blocks, i.e. all parameters of the chosen block are made trainable.
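
In combination with PEFT_Sam's `attention_layers_to_update` argument (see below), this allows classical fine-tuning of only a subset of encoder blocks while the rest of the image encoder stays frozen. A sketch under the same illustrative setup as above:

```python
from segment_anything import sam_model_registry
from micro_sam.models.peft_sam import PEFT_Sam, ClassicalSurgery

sam = sam_model_registry["vit_b"](checkpoint=None)
n_blocks = len(sam.image_encoder.blocks)  # 12 for 'vit_b'

# Fully fine-tune only the last two encoder blocks; everything else stays frozen.
peft_sam = PEFT_Sam(
    model=sam,
    peft_module=ClassicalSurgery,
    attention_layers_to_update=[n_blocks - 2, n_blocks - 1],
)

trainable_blocks = sorted({
    name.split(".")[1]
    for name, p in peft_sam.sam.image_encoder.named_parameters()
    if p.requires_grad and name.startswith("blocks.")
})
print(trainable_blocks)  # expected: ['10', '11']
```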

class PEFT_Sam(torch.nn.modules.module.Module):
394class PEFT_Sam(nn.Module):
395    """Wraps the Segment Anything model's image encoder with different parameter efficient finetuning methods.
396
397    Args:
398        model: The Segment Anything model.
399        rank: The rank for low-rank adaptation.
400        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
401        attention_layers_to_update: Which specific layers we apply PEFT methods to.
402            For reference, the total number of blocks for 'vit_b' is 12, for 'vit_l' is 24 and for 'vit_h' is 32.
403        quantize: Whether to quantize the model for lower precision training.
404        module_kwargs: The additional arguments for the respective PEFT modules.
405    """
406
407    def __init__(
408        self,
409        model: Sam,
410        rank: Optional[int] = None,
411        peft_module: nn.Module = LoRASurgery,
412        attention_layers_to_update: Optional[List[int]] = None,
413        quantize: bool = False,
414        **module_kwargs
415    ):
416        super().__init__()
417
418        if issubclass(peft_module, Union[LoRASurgery, FacTSurgery]) and (not rank or rank <= 0):
419            raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")
420
421        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer]), (
422            "Invalid PEFT module"
423        )
424        if attention_layers_to_update:
425            self.peft_layers = attention_layers_to_update
426        else:   # Applies PEFT to the image encoder by default
427            self.peft_layers = list(range(len(model.image_encoder.blocks)))
428
429        self.peft_module = peft_module
430        self.peft_blocks = []
431
432        # Whether to quantize the linear layers to 4 bit precision.
433        # NOTE: This is currently supported for CUDA-supported devices only.
434        if quantize:
435            if not _have_bnb:
436                raise ModuleNotFoundError("Please install 'bitsandbytes'.")
437
438            for name, module in model.image_encoder.named_modules():
439                if isinstance(module, torch.nn.Linear):
440                    *parent_path, layer_name = name.split(".")
441                    parent_module = model.image_encoder
442
443                    for sub_module in parent_path:
444                        parent_module = getattr(parent_module, sub_module)
445
446                    # Create the new Linear4bit layer
447                    linear_q = bnb.nn.Linear4bit(
448                        module.in_features,
449                        module.out_features,
450                        bias=False if module.bias is None else True,
451                    )
452                    # Assign weights and bias to the new layer
453                    new_weight = bnb.nn.Params4bit(
454                        data=module.weight,
455                        requires_grad=False,
456                    )
457                    linear_q.weight = new_weight
458                    if module.bias is not None:
459                        linear_q.bias = torch.nn.Parameter(module.bias)
460
461                    # Replace the original linear layer with the quantized one
462                    setattr(parent_module, layer_name, linear_q)
463
464        # Let's freeze all the pretrained image encoder layers first
465        for param in model.image_encoder.parameters():
466            param.requires_grad = False
467
468        # Add scale and shift parameters to the patch embedding layers.
469        if issubclass(self.peft_module, SSFSurgery):
470            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))
471
472        # If specified, the attention layers to update should match the available blocks.
473        if attention_layers_to_update and (
474            set(attention_layers_to_update) - set(list(range(len(model.image_encoder.blocks))))
475        ):
476            raise ValueError("The chosen layer(s) to apply PEFT method is not a valid transformer block id.")
477
478        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
479
480            # If we only want specific layers with PEFT instead of all
481            if t_layer_i not in self.peft_layers:
482                continue
483
484            if issubclass(self.peft_module, SelectiveSurgery):
485                self.peft_blocks.append(self.peft_module(block=blk))
486            else:
487                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))
488
489        self.peft_blocks = nn.ModuleList(self.peft_blocks)
490        self.sam = model
491
492    def forward(self, batched_input, multimask_output):
493        return self.sam(batched_input, multimask_output)

Wraps the Segment Anything model's image encoder with different parameter-efficient fine-tuning (PEFT) methods. A minimal usage sketch follows the argument list.

Arguments:
  • model: The Segment Anything model.
  • rank: The rank for low-rank adaptation.
  • peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
  • attention_layers_to_update: The specific blocks to apply the PEFT method to. For reference, the total number of blocks is 12 for 'vit_b', 24 for 'vit_l' and 32 for 'vit_h'.
  • quantize: Whether to quantize the frozen base linear layers of the image encoder to 4-bit precision for lower-memory training (requires 'bitsandbytes' and a CUDA device).
  • module_kwargs: Additional keyword arguments forwarded to the respective PEFT module.
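
A minimal end-to-end sketch of the default LoRA setup; the checkpoint-free model construction and the rank value are illustrative assumptions. `quantize=True` would additionally require `bitsandbytes` and a CUDA device, as noted in the source above.

```python
from segment_anything import sam_model_registry
from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

# Randomly initialized SAM for illustration; pass a checkpoint path for real use.
sam = sam_model_registry["vit_b"](checkpoint=None)

peft_sam = PEFT_Sam(
    model=sam,
    rank=4,                           # required (> 0) for LoRASurgery / FacTSurgery
    peft_module=LoRASurgery,
    attention_layers_to_update=None,  # None -> apply LoRA to every encoder block
    # quantize=True,                  # optional: 4-bit frozen base weights (bitsandbytes + CUDA)
    update_matrices=["q", "v"],       # forwarded to LoRASurgery via **module_kwargs
)

# Inside the image encoder only the newly added LoRA matrices are trainable;
# the prompt encoder and mask decoder are not touched by this wrapper.
n_trainable = sum(
    p.numel() for p in peft_sam.sam.image_encoder.parameters() if p.requires_grad
)
print(f"trainable image-encoder parameters: {n_trainable}")

# `peft_sam` is called like the underlying SAM model:
#   predictions = peft_sam(batched_input, multimask_output)
```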