micro_sam.models.peft_sam

import math
from typing import List, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam

try:
    import bitsandbytes as bnb
    _have_bnb = True
except ImportError:
    _have_bnb = False


class LoRASurgery(nn.Module):
    """Operates on the attention layers for performing low-rank adaptation.

    (Inspired by: https://github.com/JamesQFreeman/Sam_LoRA/)

    In SAM, it is implemented as:
    ```python
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    ```

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing LoRA.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features
        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
        self.rank = rank

        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)

        self.reset_parameters()

        block.attn.qkv = self

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.w_b_linear_q.weight)
        nn.init.zeros_(self.w_b_linear_v.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)  # shape: (B, H, W, 3 * C)
        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
        qkv = torch.cat(
            [
                qkv[:, :, :, :self.dim] + new_q,  # add the low-rank update to q
                qkv[:, :, :, self.dim:-self.dim],  # keep k unchanged
                qkv[:, :, :, -self.dim:] + new_v  # add the low-rank update to v
            ], dim=-1
        )

        return qkv


class FacTSurgery(nn.Module):
    """Operates on the attention layers for performing factorized attention.

    (Inspired by: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention blocks for implementing FacT.
        dropout: The dropout rate for the factorized attention.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        dropout: Optional[float] = 0.1,
    ):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        self.q_FacTs = nn.Linear(rank, rank, bias=False)
        self.v_FacTs = nn.Linear(rank, rank, bias=False)

        self.dropout = dropout
        if self.dropout is not None:
            self.dp_q = nn.Dropout(self.dropout)
            self.dp_v = nn.Dropout(self.dropout)

        self.FacTu = nn.Linear(self.dim, rank, bias=False)
        self.FacTv = nn.Linear(rank, self.dim, bias=False)

        block.attn.qkv = self

    def forward(self, x):
        qkv = self.qkv_proj(x)

        new_q = self.q_FacTs(self.FacTu(x))
        new_v = self.v_FacTs(self.FacTu(x))

        if self.dropout is not None:
            new_q = self.dp_q(new_q)
            new_v = self.dp_v(new_v)

        new_q = self.FacTv(new_q)
        new_v = self.FacTv(new_v)

        # NOTE: The scaling factor is set to 1, as it can be tuned via the learning rate.
        qkv = torch.cat(
            [
                qkv[:, :, :, :self.dim] + new_q,  # add the factorized update to q
                qkv[:, :, :, self.dim:-self.dim],  # keep k unchanged
                qkv[:, :, :, -self.dim:] + new_v  # add the factorized update to v
            ], dim=-1
        )

        return qkv


class ScaleShiftLayer(nn.Module):
    def __init__(self, layer, dim):
        super().__init__()
        self.layer = layer
        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
        # NOTE: The original layer is replaced by this wrapper in the caller (see `SSFSurgery`).

    def forward(self, x):
        x = self.layer(x)
        assert self.scale.shape == self.shift.shape
        if x.shape[-1] == self.scale.shape[0]:
            return x * self.scale + self.shift
        elif x.shape[1] == self.scale.shape[0]:
            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
        else:
            raise ValueError('Input tensors do not match the shape of the scale factors.')


class SSFSurgery(nn.Module):
    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.

    Args:
        rank: This parameter is not used in `SSFSurgery`. It is kept here for consistency.
        block: The chosen block for implementing SSF.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.block = block

        # If we get a transformer block (with multiple sub-layers), we perform surgery on each layer.
        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])

        # If we get the embedding block, add one ScaleShiftLayer.
        elif hasattr(block, "patch_embed"):
            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)

    def forward(self, x):
        return x


class SelectiveSurgery(nn.Module):
    """Base class for selectively allowing gradient updates for certain parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def allow_gradient_update_for_parameters(
        self,
        prefix: Optional[List[str]] = None,
        suffix: Optional[List[str]] = None,
        infix: Optional[List[str]] = None,
    ):
        """Decides which parameter attributes to match for allowing gradient updates.

        Args:
            prefix: Matches the part of the parameter name at the front.
            suffix: Matches the part of the parameter name at the end.
            infix: Matches parts of the parameter name occurring in between.
        """
        for k, v in self.block.named_parameters():
            if prefix is not None and k.startswith(tuple(prefix)):
                v.requires_grad = True

            if suffix is not None and k.endswith(tuple(suffix)):
                v.requires_grad = True

            if infix is not None:
                for per_infix in infix:
                    if k.find(per_infix) != -1:
                        v.requires_grad = True

    def forward(self, x):
        return x


class AdaptFormer(nn.Module):
    """Adds the AdaptFormer module in place of the MLP layers.

    Args:
        rank: The rank is not used in this class but kept here for consistency.
        block: The chosen encoder block for implementing AdaptFormer.
        alpha: A parameter that scales the adapter path. Can either be learnable or a fixed value.
        dropout: The dropout rate for the dropout layer between the down and up projection layers.
        projection_size: The size of the projection layer.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
        dropout: Optional[float] = None,  # Does not have an obvious advantage.
        projection_size: int = 64,  # Stable choice from our preliminary exp.
    ):
        super().__init__()

        self.mlp_proj = block.mlp
        self.n_embd = block.mlp.lin1.in_features

        if alpha == 'learnable_scalar':
            self.alpha = nn.Parameter(torch.ones(1))
        else:
            self.alpha = alpha

        self.projection_size = projection_size
        self.dropout = dropout

        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.projection_size, self.n_embd)

        block.mlp = self

        if self.dropout is not None:
            self.dropout_layer = nn.Dropout(self.dropout)

        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
        nn.init.zeros_(self.up_proj.weight)
        nn.init.zeros_(self.down_proj.bias)
        nn.init.zeros_(self.up_proj.bias)

    def forward(self, x):
        residual = x
        mlp_output = self.mlp_proj(x)

        down = self.down_proj(x)
        down = self.non_linear_func(down)

        if self.dropout is not None:
            down = self.dropout_layer(down)

        up = self.up_proj(down)
        up = up * self.alpha
        output = up + residual + mlp_output

        return output


class AttentionSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for parameters in attention layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the attention layers in the image encoder.
        self.allow_gradient_update_for_parameters(prefix=["attn"])


class BiasSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for bias parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the bias parameters in the image encoder.
        self.allow_gradient_update_for_parameters(suffix=["bias"])


class LayerNormSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates in normalization layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the LayerNorm parameters in the image encoder.
        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])


class PEFT_Sam(nn.Module):
    """Wraps the Segment Anything model's image encoder with different parameter-efficient finetuning methods.

    Args:
        model: The Segment Anything model.
        rank: The rank for low-rank adaptation.
        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
        attention_layers_to_update: The specific layers to apply the PEFT method to. Applies to all blocks by default.
        quantize: Whether to quantize the model for lower precision training.
    """

    def __init__(
        self,
        model: Sam,
        rank: Optional[int] = None,
        peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Optional[List[int]] = None,
        quantize: bool = False,
        **module_kwargs
    ):
        super().__init__()

        if issubclass(peft_module, (LoRASurgery, FacTSurgery)) and (not rank or rank <= 0):
            raise RuntimeError("The chosen PEFT method cannot run without a valid rank choice.")

        assert issubclass(peft_module, (LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer)), (
            "Invalid PEFT module"
        )

        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
        else:   # Applies PEFT to all image encoder blocks by default.
            self.peft_layers = list(range(len(model.image_encoder.blocks)))

        self.peft_module = peft_module
        self.peft_blocks = []

        # Whether to quantize the linear layers to 4-bit precision.
        # NOTE: This is currently supported for CUDA devices only.
        if quantize:
            if not _have_bnb:
                raise ModuleNotFoundError("Please install 'bitsandbytes'.")

            for name, module in model.image_encoder.named_modules():
                if isinstance(module, torch.nn.Linear):
                    *parent_path, layer_name = name.split(".")
                    parent_module = model.image_encoder

                    for sub_module in parent_path:
                        parent_module = getattr(parent_module, sub_module)

                    # Create the new Linear4bit layer.
                    linear_q = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        bias=module.bias is not None,
                    )
                    # Assign weights and bias to the new layer.
                    new_weight = bnb.nn.Params4bit(
                        data=module.weight,
                        requires_grad=False,
                    )
                    linear_q.weight = new_weight
                    if module.bias is not None:
                        linear_q.bias = torch.nn.Parameter(module.bias)

                    # Replace the original linear layer with the quantized one.
                    setattr(parent_module, layer_name, linear_q)

        # Let's freeze all the pretrained image encoder layers first.
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        # Add scale and shift parameters to the patch embedding layers.
        if issubclass(self.peft_module, SSFSurgery):
            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
            # If we only want specific layers with PEFT instead of all.
            if t_layer_i not in self.peft_layers:
                continue

            if issubclass(self.peft_module, SelectiveSurgery):
                self.peft_blocks.append(self.peft_module(block=blk))
            else:
                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

        self.peft_blocks = nn.ModuleList(self.peft_blocks)

        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)
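Example usage (a minimal sketch, not part of the module above): `PEFT_Sam` wraps a `Sam` instance, freezes the image encoder, and attaches the chosen surgery to each attention block, so only the adapter parameters (plus the prompt encoder and mask decoder) remain trainable. The snippet assumes a SAM model built via `segment_anything`'s `sam_model_registry`; pass `checkpoint=...` there to start from pretrained weights.

```python
from segment_anything import sam_model_registry

from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

# Build a (randomly initialized) ViT-B SAM model; use `checkpoint=...` for pretrained weights.
sam = sam_model_registry["vit_b"]()

# Wrap the image encoder with rank-4 LoRA on the q and v projections of every attention block.
peft_sam = PEFT_Sam(sam, rank=4, peft_module=LoRASurgery)

# The frozen encoder weights stay fixed; only the LoRA matrices require gradients.
encoder_params = list(peft_sam.sam.image_encoder.parameters())
trainable = sum(p.numel() for p in encoder_params if p.requires_grad)
print(f"Trainable image encoder parameters: {trainable} / {sum(p.numel() for p in encoder_params)}")
```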
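The same wrapper interface covers the other PEFT variants defined above; only the keyword arguments differ. A hedged sketch under the same assumptions (the argument values below are illustrative, not recommendations):

```python
from segment_anything import sam_model_registry

from micro_sam.models.peft_sam import PEFT_Sam, FacTSurgery, BiasSurgery

# FacT: factorized q/v updates with optional dropout, restricted here to the last six blocks of ViT-B.
fact_sam = PEFT_Sam(
    sam_model_registry["vit_b"](),
    rank=4,
    peft_module=FacTSurgery,
    attention_layers_to_update=list(range(6, 12)),
    dropout=0.1,
)

# Bias-only tuning: the selective-surgery methods do not need a rank.
bias_sam = PEFT_Sam(sam_model_registry["vit_b"](), peft_module=BiasSurgery)
```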