micro_sam.models.peft_sam

import math
from typing import List, Optional, Type

import torch.nn as nn

from segment_anything.modeling import Sam


class LoRASurgery(nn.Module):
    """Operates on the attention layers for performing low-rank adaptation (LoRA).

    (Inspired by: https://github.com/JamesQFreeman/Sam_LoRA/)

    In SAM, the attention projection is implemented as:
    ```python
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    ```

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention block for implementing LoRA.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        # Low-rank decomposition (A and B matrices) for the query and value projections.
        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)

        self.reset_parameters()

        # Replace the block's qkv projection with this LoRA-augmented module.
        block.attn.qkv = self

    def reset_parameters(self):
        # The A matrices are initialized randomly, the B matrices to zero,
        # so that the low-rank update is zero at the start of training.
        nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.w_b_linear_q.weight)
        nn.init.zeros_(self.w_b_linear_v.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)  # shape: (B, H, W, 3 * dim)
        new_q = self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.w_b_linear_v(self.w_a_linear_v(x))
        qkv[:, :, :, :self.dim] += new_q
        qkv[:, :, :, -self.dim:] += new_v
        return qkv
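
The surgery registers itself as the block's new `qkv` projection, so the low-rank update is applied transparently whenever the attention layer runs. Below is a minimal sketch of that behavior; the `SimpleNamespace` stand-in block and all dimensions are illustrative and not part of `micro_sam`. Since the B matrices are zero-initialized, the wrapped projection initially reproduces the original qkv output.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn

from micro_sam.models.peft_sam import LoRASurgery

dim = 64  # illustrative embedding dimension
# Minimal stand-in for a SAM image-encoder block: only `block.attn.qkv` is needed here.
block = SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3, bias=True)))
original_qkv = block.attn.qkv

lora = LoRASurgery(rank=4, block=block)
assert block.attn.qkv is lora  # the surgery replaced the projection in place

# Zero-initialized B matrices mean the update starts as a no-op.
x = torch.randn(2, 8, 8, dim)  # (B, H, W, C), as in SAM's image encoder
assert torch.allclose(lora(x), original_qkv(x))
```
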
class FacTSurgery(nn.Module):
    """Operates on the attention layers for performing factorized attention (FacT).

    (Inspired by: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention block for implementing FacT.
        dropout: The dropout rate for the factorized attention updates.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        dropout: Optional[float] = 0.1,
    ):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        # Rank x rank cores, one for the query update and one for the value update.
        self.q_FacTs = nn.Linear(rank, rank, bias=False)
        self.v_FacTs = nn.Linear(rank, rank, bias=False)

        self.dropout = dropout
        if self.dropout is not None:
            self.dp_q = nn.Dropout(self.dropout)
            self.dp_v = nn.Dropout(self.dropout)

        # Down- and up-projection factors, shared between the query and value updates.
        self.FacTu = nn.Linear(self.dim, rank, bias=False)
        self.FacTv = nn.Linear(rank, self.dim, bias=False)

        # Replace the block's qkv projection with this FacT-augmented module.
        block.attn.qkv = self

    def forward(self, x):
        qkv = self.qkv_proj(x)

        new_q = self.q_FacTs(self.FacTu(x))
        new_v = self.v_FacTs(self.FacTu(x))

        if self.dropout is not None:
            new_q = self.dp_q(new_q)
            new_v = self.dp_v(new_v)

        new_q = self.FacTv(new_q)
        new_v = self.FacTv(new_v)

        # NOTE: The scaling factor is fixed to 1, since it can be tuned via the learning rate.
        qkv[:, :, :, :self.dim] += new_q
        qkv[:, :, :, -self.dim:] += new_v

        return qkv
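
Compared to LoRA, FacT shares the down-projection `FacTu` and up-projection `FacTv` between the query and value updates and adds two small rank-by-rank cores, so the same rank costs fewer additional parameters. A rough comparison on a hypothetical stand-in block (sizes are illustrative: ViT-B embedding dimension 768, rank 4):

```python
from types import SimpleNamespace

import torch.nn as nn

from micro_sam.models.peft_sam import FacTSurgery, LoRASurgery

dim, rank = 768, 4  # illustrative: ViT-B embedding dimension and a small rank

def dummy_block(dim):
    # Minimal stand-in exposing only what the surgeries touch: `block.attn.qkv`.
    return SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3)))

def added_params(surgery):
    # Parameters introduced by the surgery itself, excluding the wrapped qkv projection.
    return sum(p.numel() for name, p in surgery.named_parameters() if not name.startswith("qkv_proj"))

print(added_params(LoRASurgery(rank=rank, block=dummy_block(dim))))  # 4 * dim * rank = 12288
print(added_params(FacTSurgery(rank=rank, block=dummy_block(dim))))  # 2 * rank**2 + 2 * dim * rank = 6176
```
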
class SelectiveSurgery(nn.Module):
    """Base class for selectively allowing gradient updates for certain parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def allow_gradient_update_for_parameters(
        self,
        prefix: Optional[List[str]] = None,
        suffix: Optional[List[str]] = None,
        infix: Optional[List[str]] = None,
    ):
        """Decides which parameter names to match for allowing gradient updates.

        Args:
            prefix: Prefixes to match at the beginning of a parameter name.
            suffix: Suffixes to match at the end of a parameter name.
            infix: Substrings to match anywhere within a parameter name.
        """
        for k, v in self.block.named_parameters():
            if prefix is not None and k.startswith(tuple(prefix)):
                v.requires_grad = True

            if suffix is not None and k.endswith(tuple(suffix)):
                v.requires_grad = True

            if infix is not None:
                for per_infix in infix:
                    if k.find(per_infix) != -1:
                        v.requires_grad = True

    def forward(self, x):
        return x
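
The matching is purely name-based: `prefix`, `suffix` and `infix` are checked independently against every entry of `named_parameters()`, and any hit switches that parameter back to trainable. A small sketch of the matching rules; the `nn.ModuleDict` stand-in only mimics the sub-module names of a real SAM block:

```python
import torch.nn as nn

from micro_sam.models.peft_sam import SelectiveSurgery

# Illustrative stand-in with sub-module names similar to a SAM image-encoder block.
block = nn.ModuleDict({
    "attn": nn.Linear(8, 24),
    "norm1": nn.LayerNorm(8),
    "mlp": nn.Linear(8, 32),
})
for param in block.parameters():
    param.requires_grad = False

surgery = SelectiveSurgery(block=block)
# Unfreeze every bias parameter and everything under `norm1`.
surgery.allow_gradient_update_for_parameters(suffix=["bias"], infix=["norm1"])

print([name for name, param in block.named_parameters() if param.requires_grad])
# ['attn.bias', 'norm1.weight', 'norm1.bias', 'mlp.bias']
```
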
class AttentionSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for parameters in attention layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the attention layers in the image encoder.
        self.allow_gradient_update_for_parameters(prefix=["attn"])


class BiasSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for bias parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the bias parameters in the image encoder.
        self.allow_gradient_update_for_parameters(suffix=["bias"])


class LayerNormSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates in normalization layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the LayerNorm parameters in the image encoder.
        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])
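
The three subclasses above simply preset these name filters. A sketch of what each one unfreezes, again on an illustrative stand-in block with SAM-like attribute names:

```python
import torch.nn as nn

from micro_sam.models.peft_sam import AttentionSurgery, BiasSurgery, LayerNormSurgery

def frozen_block():
    # Illustrative stand-in with the attribute names used by a SAM image-encoder block.
    block = nn.ModuleDict({
        "attn": nn.Linear(8, 24),
        "norm1": nn.LayerNorm(8),
        "norm2": nn.LayerNorm(8),
        "mlp": nn.Linear(8, 32),
    })
    for param in block.parameters():
        param.requires_grad = False
    return block

for surgery_cls in (AttentionSurgery, BiasSurgery, LayerNormSurgery):
    block = frozen_block()
    surgery_cls(block=block)  # unfreezes the matching parameters as a side effect
    print(surgery_cls.__name__, [n for n, p in block.named_parameters() if p.requires_grad])
# AttentionSurgery ['attn.weight', 'attn.bias']
# BiasSurgery ['attn.bias', 'norm1.bias', 'norm2.bias', 'mlp.bias']
# LayerNormSurgery ['norm1.weight', 'norm1.bias', 'norm2.weight', 'norm2.bias']
```
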
class PEFT_Sam(nn.Module):
    """Wraps the Segment Anything model's image encoder with different parameter-efficient finetuning (PEFT) methods.

    Args:
        model: The Segment Anything model.
        rank: The rank for the low-rank adaptation.
        peft_module: Wrapper class that operates on the image encoder blocks for the chosen PEFT method.
        attention_layers_to_update: The indices of the attention blocks to apply PEFT to. By default, all blocks are updated.
    """

    def __init__(
        self,
        model: Sam,
        rank: int,
        peft_module: Type[nn.Module] = LoRASurgery,
        attention_layers_to_update: Optional[List[int]] = None,
        **module_kwargs
    ):
        super().__init__()

        assert rank > 0
        assert issubclass(peft_module, (LoRASurgery, FacTSurgery, SelectiveSurgery)), "Invalid PEFT module."

        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
        else:  # Apply PEFT to all image encoder blocks by default.
            self.peft_layers = list(range(len(model.image_encoder.blocks)))

        self.peft_module = peft_module
        self.peft_blocks = []

        # First, freeze all parameters of the pretrained image encoder.
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
            # Skip the blocks for which PEFT was not requested.
            if t_layer_i not in self.peft_layers:
                continue

            # Selective surgeries only toggle `requires_grad`; the other modules also need the rank.
            if issubclass(self.peft_module, SelectiveSurgery):
                peft_block = self.peft_module(block=blk)
            else:
                peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs)

            self.peft_blocks.append(peft_block)

        self.peft_blocks = nn.ModuleList(self.peft_blocks)

        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)
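
A minimal usage sketch: wrap a `vit_b` SAM model with LoRA of rank 4 on all image-encoder blocks and report how many parameters remain trainable. The checkpoint path is a placeholder for a locally downloaded SAM checkpoint; swap `peft_module` (e.g. to `FacTSurgery` or `BiasSurgery`) or pass `attention_layers_to_update` to change the strategy.

```python
from segment_anything import sam_model_registry

from micro_sam.models.peft_sam import LoRASurgery, PEFT_Sam

# Build a SAM model; the checkpoint path is a placeholder.
sam = sam_model_registry["vit_b"](checkpoint="/path/to/sam_vit_b.pth")

# Apply LoRA with rank 4 to every attention block of the image encoder.
peft_sam = PEFT_Sam(model=sam, rank=4, peft_module=LoRASurgery)

# The image encoder is frozen; only the newly added LoRA parameters and the
# prompt encoder / mask decoder of SAM remain trainable.
trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
total = sum(p.numel() for p in peft_sam.parameters())
print(f"Trainable parameters: {trainable:,} of {total:,}")
```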