micro_sam.models.peft_sam
````python
import math
from typing import List, Union, Optional

import torch.nn as nn

from segment_anything.modeling import Sam


class LoRASurgery(nn.Module):
    """Operates on the attention layers for performing low-rank adaptation.

    (Inspired by: https://github.com/JamesQFreeman/Sam_LoRA/)

    In SAM, it is implemented as:
    ```python
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    ```

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention block for implementing LoRA.
    """
    def __init__(self, rank: int, block: nn.Module):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)

        self.reset_parameters()

        block.attn.qkv = self

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
        nn.init.zeros_(self.w_b_linear_q.weight)
        nn.init.zeros_(self.w_b_linear_v.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
        new_q = self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.w_b_linear_v(self.w_a_linear_v(x))
        qkv[:, :, :, :self.dim] += new_q
        qkv[:, :, :, -self.dim:] += new_v
        return qkv


class FacTSurgery(nn.Module):
    """Operates on the attention layers for performing factorized attention.

    (Inspired by: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)

    Args:
        rank: The rank of the decomposition matrices for updating weights in each attention layer.
        block: The chosen attention block for implementing FacT.
        dropout: The dropout rate for the factorized attention.
    """
    def __init__(
        self,
        rank: int,
        block: nn.Module,
        dropout: Optional[float] = 0.1,
    ):
        super().__init__()
        self.qkv_proj = block.attn.qkv
        self.dim = self.qkv_proj.in_features

        self.q_FacTs = nn.Linear(rank, rank, bias=False)
        self.v_FacTs = nn.Linear(rank, rank, bias=False)

        self.dropout = dropout
        if self.dropout is not None:
            self.dp_q = nn.Dropout(self.dropout)
            self.dp_v = nn.Dropout(self.dropout)

        self.FacTu = nn.Linear(self.dim, rank, bias=False)
        self.FacTv = nn.Linear(rank, self.dim, bias=False)

        block.attn.qkv = self

    def forward(self, x):
        qkv = self.qkv_proj(x)

        new_q = self.q_FacTs(self.FacTu(x))
        new_v = self.v_FacTs(self.FacTu(x))

        if self.dropout is not None:
            new_q = self.dp_q(new_q)
            new_v = self.dp_v(new_v)

        new_q = self.FacTv(new_q)
        new_v = self.FacTv(new_v)

        # NOTE: The scaling factor was set to 1 as it can be tuned via the learning rate.
        qkv[:, :, :, :self.dim] += new_q
        qkv[:, :, :, -self.dim:] += new_v

        return qkv


class SelectiveSurgery(nn.Module):
    """Base class for selectively allowing gradient updates for certain parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def allow_gradient_update_for_parameters(
        self,
        prefix: Optional[List[str]] = None,
        suffix: Optional[List[str]] = None,
        infix: Optional[List[str]] = None,
    ):
        """Decides which parameter names to match for allowing gradient updates.

        Args:
            prefix: Matches the beginning of the parameter name.
            suffix: Matches the end of the parameter name.
            infix: Matches substrings occurring anywhere in the parameter name.
        """
        for k, v in self.block.named_parameters():
            if prefix is not None and k.startswith(tuple(prefix)):
                v.requires_grad = True

            if suffix is not None and k.endswith(tuple(suffix)):
                v.requires_grad = True

            if infix is not None:
                for per_infix in infix:
                    if k.find(per_infix) != -1:
                        v.requires_grad = True

    def forward(self, x):
        return x


class AttentionSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for parameters in attention layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the attention layers in the image encoder.
        self.allow_gradient_update_for_parameters(prefix=["attn"])


class BiasSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates for bias parameters.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the bias parameters in the image encoder.
        self.allow_gradient_update_for_parameters(suffix=["bias"])


class LayerNormSurgery(SelectiveSurgery):
    """Child class for allowing gradient updates in normalization layers.
    """
    def __init__(self, block: nn.Module):
        super().__init__(block=block)
        # Allow gradient updates for the LayerNorm parameters in the image encoder.
        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])


class PEFT_Sam(nn.Module):
    """Wraps the Segment Anything model's image encoder with different parameter-efficient finetuning methods.

    Args:
        model: The Segment Anything model.
        rank: The rank for low-rank adaptation.
        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
        attention_layers_to_update: The indices of the attention blocks to which the PEFT method is applied.
    """

    def __init__(
        self,
        model: Sam,
        rank: int,
        peft_module: nn.Module = LoRASurgery,
        attention_layers_to_update: Optional[List[int]] = None,
        **module_kwargs
    ):
        super().__init__()

        assert rank > 0
        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."

        if attention_layers_to_update:
            self.peft_layers = attention_layers_to_update
        else:  # Applies PEFT to all blocks of the image encoder by default.
            self.peft_layers = list(range(len(model.image_encoder.blocks)))

        self.peft_module = peft_module
        self.peft_blocks = []

        # Let's freeze all the pretrained image encoder layers first.
        for param in model.image_encoder.parameters():
            param.requires_grad = False

        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
            # If we only want specific layers with PEFT instead of all.
            if t_layer_i not in self.peft_layers:
                continue

            if issubclass(self.peft_module, SelectiveSurgery):
                peft_block = self.peft_module(block=blk)
            else:
                peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs)

            self.peft_blocks.append(peft_block)

        self.peft_blocks = nn.ModuleList(self.peft_blocks)

        self.sam = model

    def forward(self, batched_input, multimask_output):
        return self.sam(batched_input, multimask_output)
````
class LoRASurgery(nn.Module)
Operates on the attention layers for performing low-rank adaptation.
(Inspired by: https://github.com/JamesQFreeman/Sam_LoRA/)
In SAM, it is implemented as:
```python
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
```
Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention block for implementing LoRA.
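A minimal, self-contained sketch of how `LoRASurgery` patches a `qkv` projection. The `SimpleNamespace` stand-in only mimics the `block.attn.qkv` attribute of a real SAM image-encoder block; the dimension and rank below are assumptions for illustration.

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

from micro_sam.models.peft_sam import LoRASurgery

# Stand-in for a SAM image-encoder block: only the attn.qkv projection matters here.
dim, rank = 64, 4
block = SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3)))

lora = LoRASurgery(rank=rank, block=block)  # replaces block.attn.qkv with the wrapped projection
assert block.attn.qkv is lora

x = torch.randn(1, 14, 14, dim)  # SAM's ViT encoder passes tokens as (B, H, W, C)
qkv = block.attn.qkv(x)
print(qkv.shape)  # torch.Size([1, 14, 14, 192]): q, k, v stacked along the last axis

# Since w_b_linear_q and w_b_linear_v are zero-initialized, the patched projection
# initially reproduces the pretrained one; training only has to learn the low-rank update.
```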
class FacTSurgery(nn.Module)
Operates on the attention layers for performing factorized attention.
(Inspired by: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)
Arguments:
- rank: The rank of the decomposition matrices for updating weights in each attention layer.
- block: The chosen attention block for implementing FacT.
- dropout: The dropout rate for the factorized attention.
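To see what the factorization buys, here is a rough sketch comparing the number of trainable parameters added per block by `LoRASurgery` and `FacTSurgery`. The stand-in block and the chosen dimension and rank are assumptions, not part of the library.

```python
import torch.nn as nn
from types import SimpleNamespace

from micro_sam.models.peft_sam import FacTSurgery, LoRASurgery


def make_block(dim):
    # Stand-in exposing block.attn.qkv, as a real SAM image-encoder block does.
    return SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3)))


def added_params(surgery):
    # Count only the parameters introduced by the surgery, not the wrapped qkv projection.
    return sum(p.numel() for name, p in surgery.named_parameters() if not name.startswith("qkv_proj"))


dim, rank = 768, 4  # ViT-B embedding dimension and a typical low rank
print(added_params(LoRASurgery(rank=rank, block=make_block(dim))))  # 4 * dim * rank = 12288
print(added_params(FacTSurgery(rank=rank, block=make_block(dim))))  # 2 * dim * rank + 2 * rank**2 = 6176

# FacT shares one down-projection (FacTu) and one up-projection (FacTv) between the q and v
# updates and learns only small rank x rank cores per projection, roughly halving the number
# of added parameters compared to LoRA at the same rank.
```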
class SelectiveSurgery(nn.Module)
Base class for selectively allowing gradient updates for certain parameters.
def allow_gradient_update_for_parameters(self, prefix: Optional[List[str]] = None, suffix: Optional[List[str]] = None, infix: Optional[List[str]] = None)
Decides which parameter names to match for allowing gradient updates.
Arguments:
- prefix: Matches the beginning of the parameter name.
- suffix: Matches the end of the parameter name.
- infix: Matches substrings occurring anywhere in the parameter name.
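A quick illustration of the matching rules. The toy block and its parameter names are made up and only loosely mimic a ViT block.

```python
import torch.nn as nn

from micro_sam.models.peft_sam import SelectiveSurgery

# Toy block whose parameter names loosely mimic a ViT block ("attn.qkv.*", "norm1.*", "mlp.*").
block = nn.Module()
block.attn = nn.Module()
block.attn.qkv = nn.Linear(8, 24)
block.norm1 = nn.LayerNorm(8)
block.mlp = nn.Linear(8, 8)

for p in block.parameters():
    p.requires_grad = False  # start from a fully frozen block, as PEFT_Sam does

surgery = SelectiveSurgery(block=block)
surgery.allow_gradient_update_for_parameters(prefix=["attn"], suffix=["bias"], infix=["norm1"])

for name, param in block.named_parameters():
    print(name, param.requires_grad)
# attn.qkv.weight True   (prefix "attn")
# attn.qkv.bias   True   (prefix "attn", suffix "bias")
# norm1.weight    True   (infix "norm1")
# norm1.bias      True   (suffix "bias", infix "norm1")
# mlp.weight      False
# mlp.bias        True   (suffix "bias")
```

The child classes below simply call this method with a fixed prefix (`AttentionSurgery`), suffix (`BiasSurgery`), or infix (`LayerNormSurgery`).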
class AttentionSurgery(SelectiveSurgery)
Child class for allowing gradient updates for parameters in attention layers.
class BiasSurgery(SelectiveSurgery)
Child class for allowing gradient updates for bias parameters.
class LayerNormSurgery(SelectiveSurgery)
Child class for allowing gradient updates in normalization layers.
class PEFT_Sam(nn.Module)
Wraps the Segment Anything model's image encoder with different parameter-efficient finetuning (PEFT) methods.
Arguments:
- model: The Segment Anything model.
- rank: The rank for low-rank adaptation.
- peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
- attention_layers_to_update: The indices of the attention blocks to which the PEFT method is applied. By default, all blocks are updated.
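A usage sketch: the checkpoint path is a placeholder and the rank and layer choice are assumptions, but `sam_model_registry` is the standard way to build a SAM model from the `segment_anything` package.

```python
from segment_anything import sam_model_registry

from micro_sam.models.peft_sam import LoRASurgery, PEFT_Sam

# Build a SAM model (ViT-B backbone) from a locally downloaded checkpoint.
sam = sam_model_registry["vit_b"](checkpoint="/path/to/sam_vit_b_01ec64.pth")

# Apply rank-4 LoRA to every attention block of the image encoder. Pass e.g.
# attention_layers_to_update=[9, 10, 11] to restrict the adaptation to the last blocks.
peft_sam = PEFT_Sam(model=sam, rank=4, peft_module=LoRASurgery)

# Within the image encoder only the LoRA matrices remain trainable; the prompt encoder
# and mask decoder are not touched by the wrapper and keep their default requires_grad=True.
trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
total = sum(p.numel() for p in peft_sam.parameters())
print(f"Trainable parameters: {trainable} / {total}")
```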