micro_sam.models.peft_sam

  1import math
  2from typing import List, Type, Union, Optional
  3
  4import torch
  5import torch.nn as nn
  6
  7from segment_anything.modeling import Sam
  8
  9
 10class LoRASurgery(nn.Module):
 11    """Operates on the attention layers for performing low-rank adaptation.
 12
 13    (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/)
 14
 15    In SAM, the attention layer's fused qkv projection is implemented as:
 16    ```python
 17    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
 18    B, N, C = x.shape
 19    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
 20    q, k, v = qkv.unbind(0)
 21    ```
 22
 23    Args:
 24        rank: The rank of the decomposition matrices for updating weights in each attention layer.
 25        block: The chosen attention block for implementing LoRA.
 26    """
 27    def __init__(self, rank: int, block: nn.Module):
 28        super().__init__()
 29        self.qkv_proj = block.attn.qkv
 30        self.dim = self.qkv_proj.in_features
 31        self.alpha = 1  # From our experiments, 'alpha' as 1 gives the best performance.
 32        self.rank = rank
 33
 34        self.w_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
 35        self.w_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
 36        self.w_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
 37        self.w_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
 38
 39        self.reset_parameters()
 40
 41        block.attn.qkv = self
 42
 43    def reset_parameters(self):
 44        nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5))
 45        nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5))
 46        nn.init.zeros_(self.w_b_linear_q.weight)
 47        nn.init.zeros_(self.w_b_linear_v.weight)
 48
 49    def forward(self, x):
 50        qkv = self.qkv_proj(x)  # Shape: (B, H, W, 3 * C) in SAM's ViT attention blocks.
 51        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
 52        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
 53        qkv[:, :, :, :self.dim] += new_q
 54        qkv[:, :, :, -self.dim:] += new_v
 55        return qkv
 56
 57
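
To illustrate what the surgery does, here is a minimal sketch (not part of the module) that applies `LoRASurgery` to a toy stand-in for a SAM attention block; the `SimpleNamespace` block and all shapes below are illustrative assumptions, not the real SAM encoder:

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

from micro_sam.models.peft_sam import LoRASurgery

dim = 64
# Toy stand-in exposing only the attribute LoRASurgery touches: `block.attn.qkv`.
block = SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3)))

lora = LoRASurgery(rank=4, block=block)  # replaces block.attn.qkv with the LoRA wrapper
x = torch.randn(1, 14, 14, dim)          # (B, H, W, C), as in SAM's windowed attention
out = block.attn.qkv(x)
print(out.shape)                         # torch.Size([1, 14, 14, 192]), i.e. 3 * dim

# The B matrices are zero-initialized, so the LoRA path contributes nothing at the start.
assert torch.allclose(out, lora.qkv_proj(x))
```
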
 58class FacTSurgery(nn.Module):
 59    """Operates on the attention layers for performing factorized attention.
 60
 61    (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)
 62
 63    Args:
 64        rank: The rank of the decomposition matrices for updating weights in each attention layer.
 65        block: The chosen attention block for implementing FacT.
 66        dropout: The dropout rate for the factorized attention.
 67    """
 68    def __init__(
 69        self,
 70        rank: int,
 71        block: nn.Module,
 72        dropout: Optional[float] = 0.1,
 73    ):
 74        super().__init__()
 75        self.qkv_proj = block.attn.qkv
 76        self.dim = self.qkv_proj.in_features
 77
 78        self.q_FacTs = nn.Linear(rank, rank, bias=False)
 79        self.v_FacTs = nn.Linear(rank, rank, bias=False)
 80
 81        self.dropout = dropout
 82        if self.dropout is not None:
 83            self.dp_q = nn.Dropout(self.dropout)
 84            self.dp_v = nn.Dropout(self.dropout)
 85
 86        self.FacTu = nn.Linear(self.dim, rank, bias=False)
 87        self.FacTv = nn.Linear(rank, self.dim, bias=False)
 88
 89        block.attn.qkv = self
 90
 91    def forward(self, x):
 92        qkv = self.qkv_proj(x)
 93
 94        new_q = self.q_FacTs(self.FacTu(x))
 95        new_v = self.v_FacTs(self.FacTu(x))
 96
 97        if self.dropout is not None:
 98            new_q = self.dp_q(new_q)
 99            new_v = self.dp_v(new_v)
100
101        new_q = self.FacTv(new_q)
102        new_v = self.FacTv(new_v)
103
104        # NOTE: The scaling factor is set to 1, since it can be tuned via the learning rate.
105        qkv[:, :, :, : self.dim] += new_q
106        qkv[:, :, :, -self.dim:] += new_v
107
108        return qkv
109
110
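
For intuition on how FacT differs from LoRA in this implementation: the rank-`rank` factors `FacTu`/`FacTv` are shared between the q and v updates, so the surgery adds fewer parameters per block at the same rank. A rough, illustrative comparison (toy blocks, with an assumed ViT-B embedding dimension of 768):

```python
import torch.nn as nn
from types import SimpleNamespace

from micro_sam.models.peft_sam import FacTSurgery, LoRASurgery

dim, rank = 768, 4  # illustrative: ViT-B embedding dim and a small rank

def toy_block():
    return SimpleNamespace(attn=SimpleNamespace(qkv=nn.Linear(dim, dim * 3)))

def added_params(surgery):
    # Count only the parameters introduced by the surgery, not the wrapped qkv projection.
    return sum(p.numel() for name, p in surgery.named_parameters() if not name.startswith("qkv_proj"))

fact = FacTSurgery(rank=rank, block=toy_block())
lora = LoRASurgery(rank=rank, block=toy_block())
print(added_params(fact))  # 2 * dim * rank + 2 * rank * rank = 6176
print(added_params(lora))  # 4 * dim * rank = 12288
```
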
111class ScaleShiftLayer(nn.Module):
112    def __init__(self, layer, dim):
113        super().__init__()
114        self.layer = layer
115        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
116        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
118        # NOTE: The original layer is replaced with this wrapper by `SSFSurgery` below.
118
119    def forward(self, x):
120        x = self.layer(x)
121        assert self.scale.shape == self.shift.shape
122        if x.shape[-1] == self.scale.shape[0]:
123            return x * self.scale + self.shift
124        elif x.shape[1] == self.scale.shape[0]:
125            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
126        else:
127            raise ValueError('The input tensor does not match the shape of the scale factors.')
128
129
130class SSFSurgery(nn.Module):
131    """Operates on all layers in the transformer block for adding learnable scale and shift parameters.
132
133    Args:
134        rank: This parameter is not used in `SSFSurgery`. It is kept here for consistency.
135        block: The chosen block for implementing SSF. This can either be a transformer block
136            (with attention and MLP sub-layers) or the patch embedding block.
137    """
138    def __init__(self, rank: int, block: nn.Module):
139        super().__init__()
140        self.block = block
141
142        # If we get a transformer block (w. multiple sub-layers), we perform surgery on each layer.
143        if hasattr(block, "attn"):  # the minimum assumption is to verify the attention layers.
144            block.attn.qkv = ScaleShiftLayer(block.attn.qkv, block.attn.qkv.in_features*3)
145            block.attn.proj = ScaleShiftLayer(block.attn.proj, block.attn.proj.in_features)
146            block.mlp.lin1 = ScaleShiftLayer(block.mlp.lin1, block.mlp.lin1.out_features)
147            block.mlp.lin2 = ScaleShiftLayer(block.mlp.lin2, block.mlp.lin2.out_features)
148            block.norm1 = ScaleShiftLayer(block.norm1, block.norm1.normalized_shape[0])
149            block.norm2 = ScaleShiftLayer(block.norm2, block.norm2.normalized_shape[0])
150
151        # If we get the embedding block, add one ScaleShiftLayer
152        elif hasattr(block, "patch_embed"):
153            block.proj = ScaleShiftLayer(block.proj, block.proj.out_channels)
154
155    def forward(self, x):
156        return x
157
158
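
`ScaleShiftLayer` only wraps a given layer; the actual replacement inside the SAM blocks is performed by `SSFSurgery` above. A minimal sketch of the wrapper on a plain linear layer (the sizes below are illustrative):

```python
import torch
import torch.nn as nn

from micro_sam.models.peft_sam import ScaleShiftLayer

lin = nn.Linear(32, 64)
ssf = ScaleShiftLayer(lin, dim=64)  # scale/shift match the wrapped layer's output features

x = torch.randn(2, 16, 32)
y = ssf(x)  # equivalent to lin(x) * scale + shift, broadcast over the last dimension
assert y.shape == (2, 16, 64)
assert torch.allclose(y, lin(x) * ssf.scale + ssf.shift)
```
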
159class SelectiveSurgery(nn.Module):
160    """Base class for selectively allowing gradient updates for certain parameters.
161    """
162    def __init__(self, block: nn.Module):
163        super().__init__()
164        self.block = block
165
166    def allow_gradient_update_for_parameters(
167        self,
168        prefix: Optional[List[str]] = None,
169        suffix: Optional[List[str]] = None,
170        infix: Optional[List[str]] = None,
171    ):
172        """This function decides the parameter attributes to match for allowing gradient updates.
173
174        Args:
175            prefix: Matches the part of parameter name in front.
176            suffix: Matches the part of parameter name at the end.
177            infix: Matches parts of parameter name occuring in between.
178        """
179        for k, v in self.block.named_parameters():
180            if prefix is not None and k.startswith(tuple(prefix)):
181                v.requires_grad = True
182
183            if suffix is not None and k.endswith(tuple(suffix)):
184                v.requires_grad = True
185
186            if infix is not None:
187                for per_infix in infix:
188                    if k.find(per_infix) != -1:
189                        v.requires_grad = True
190
191    def forward(self, x):
192        return x
193
194
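
A minimal sketch of how the prefix/suffix/infix matching is used, here via the `BiasSurgery` subclass defined further below (the `nn.Sequential` block is a toy stand-in, not a SAM block):

```python
import torch.nn as nn

from micro_sam.models.peft_sam import BiasSurgery

block = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
for param in block.parameters():  # PEFT_Sam freezes the image encoder before the surgery
    param.requires_grad = False

BiasSurgery(block=block)  # re-enables every parameter whose name ends in "bias"
print([name for name, p in block.named_parameters() if p.requires_grad])
# ['0.bias', '1.bias']
```
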
195class AdaptFormer(nn.Module):
196    """Adds AdaptFormer Module in place of the MLP Layers
197
198    Args:
199        rank: The rank is not used in this class but kept here for consistency.
200        block: The chosen encoder block for implementing AdaptFormer.
201        alpha: A parameters that scales the Adapter path. Can be either learnable or some fixed value.
202        dropout: The dropout rate for the dropout layer between down and up projection layer.
203        projection_size: The size of the projection layer.
204    """
205    def __init__(
206        self,
207        rank: int,
208        block: nn.Module,
209        alpha: Optional[Union[str, float]] = "learnable_scalar",  # Stable choice from our preliminary exp.
210        dropout: Optional[float] = None,  # Does not have an obvious advantage.
211        projection_size: int = 64,  # Stable choice from our preliminary exp.
212    ):
213        super().__init__()
214
215        self.mlp_proj = block.mlp
216        self.n_embd = block.mlp.lin1.in_features
217
218        if alpha == 'learnable_scalar':
219            self.alpha = nn.Parameter(torch.ones(1))
220        else:
221            self.alpha = alpha
222
223        self.projection_size = projection_size
224        self.dropout = dropout
225
226        self.down_proj = nn.Linear(self.n_embd, self.projection_size)
227        self.non_linear_func = nn.ReLU()
228        self.up_proj = nn.Linear(self.projection_size, self.n_embd)
229
230        block.mlp = self
231
232        if self.dropout is not None:
233            self.dropout_layer = nn.Dropout(self.dropout)
234
235        nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
236        nn.init.zeros_(self.up_proj.weight)
237        nn.init.zeros_(self.down_proj.bias)
238        nn.init.zeros_(self.up_proj.bias)
239
240    def forward(self, x):
241        residual = x
242        mlp_output = self.mlp_proj(x)
243
244        down = self.down_proj(x)
245        down = self.non_linear_func(down)
246
247        if self.dropout is not None:
248            down = self.dropout_layer(down)
249
250        up = self.up_proj(down)
251        up = up * self.alpha
252        output = up + residual + mlp_output
253
254        return output
255
256
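
A minimal sketch of the adapter path on a toy MLP (the `TinyMLP` class and `SimpleNamespace` block are illustrative stand-ins for SAM's MLP block): the up projection is zero-initialized, so right after surgery the adapter path contributes nothing beyond the extra residual term.

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

from micro_sam.models.peft_sam import AdaptFormer


class TinyMLP(nn.Module):
    """Mimics the interface AdaptFormer relies on: a `lin1` attribute and a callable module."""
    def __init__(self, dim, hidden):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden)
        self.lin2 = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.lin2(torch.relu(self.lin1(x)))


dim = 96
block = SimpleNamespace(mlp=TinyMLP(dim, 4 * dim))
adapter = AdaptFormer(rank=1, block=block)  # `rank` is unused here (see the docstring above)

x = torch.randn(2, 14, 14, dim)
y = block.mlp(x)  # now the adapter: MLP(x) + x + alpha * up(ReLU(down(x)))
assert torch.allclose(y, adapter.mlp_proj(x) + x)  # adapter path is zero at initialization
```
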
257class AttentionSurgery(SelectiveSurgery):
258    """Child class for allowing gradient updates for parameters in attention layers.
259    """
260    def __init__(self, block: nn.Module):
261        super().__init__(block=block)
262        # Allow gradient updates for the attention layers in the image encoder.
263        self.allow_gradient_update_for_parameters(prefix=["attn"])
264
265
266class BiasSurgery(SelectiveSurgery):
267    """Child class for allowing gradient updates for bias parameters.
268    """
269    def __init__(self, block: nn.Module):
270        super().__init__(block=block)
271        # Allow gradient updates for the bias parameters in the image encoder.
272        self.allow_gradient_update_for_parameters(suffix=["bias"])
273
274
275class LayerNormSurgery(SelectiveSurgery):
276    """Child class for allowing gradient updates in normalization layers.
277    """
278    def __init__(self, block: nn.Module):
279        super().__init__(block=block)
280        # Allow gradient updates for the LayerNorm parameters in the image encoder.
281        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])
282
283
284class PEFT_Sam(nn.Module):
285    """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
286
287    Args:
288        model: The Segment Anything model.
289        rank: The rank for low-rank adaptation.
290        peft_module: Wrapper to operate on the image encoder blocks for the PEFT method.
292        attention_layers_to_update: The image encoder blocks to apply the PEFT method to. Defaults to all blocks.
292    """
293
294    def __init__(
295        self,
296        model: Sam,
297        rank: int,
299        peft_module: Type[nn.Module] = LoRASurgery,
300        attention_layers_to_update: Optional[List[int]] = None,
300        **module_kwargs
301    ):
302        super().__init__()
303
304        assert rank > 0
305
307        assert issubclass(peft_module, (LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery, AdaptFormer)), (
307            "Invalid PEFT module"
308        )
309
310        if attention_layers_to_update:
311            self.peft_layers = attention_layers_to_update
312        else:   # Applies PEFT to all blocks of the image encoder by default.
313            self.peft_layers = list(range(len(model.image_encoder.blocks)))
314
315        self.peft_module = peft_module
316        self.peft_blocks = []
317
318        # let's freeze all the pretrained image encoder layers first
319        for param in model.image_encoder.parameters():
320            param.requires_grad = False
321
322        # Add scale and shift parameters to the patch embedding layers.
323        if issubclass(self.peft_module, SSFSurgery):
324            self.peft_blocks.append(self.peft_module(rank=rank, block=model.image_encoder.patch_embed))
325
326        for t_layer_i, blk in enumerate(model.image_encoder.blocks):
327            # If we only want specific layers with PEFT instead of all
328            if t_layer_i not in self.peft_layers:
329                continue
330
331            if issubclass(self.peft_module, SelectiveSurgery):
332                self.peft_blocks.append(self.peft_module(block=blk))
333            else:
334                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))
335
336        self.peft_blocks = nn.ModuleList(self.peft_blocks)
337
338        self.sam = model
339
340    def forward(self, batched_input, multimask_output):
341        return self.sam(batched_input, multimask_output)
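
A hedged usage sketch of the wrapper: build a vanilla SAM model via `segment_anything`'s `sam_model_registry`, wrap it with LoRA, and inspect how many parameters remain trainable. The checkpoint filename is a placeholder; any Segment Anything checkpoint matching the chosen model type works.

```python
from segment_anything import sam_model_registry

from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

# Build a vanilla SAM model; the checkpoint path below is a placeholder.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Wrap it: the image encoder is frozen and LoRA layers are added to every attention block.
peft_sam = PEFT_Sam(model=sam, rank=4, peft_module=LoRASurgery)

# Only the LoRA matrices plus the (unfrozen) prompt encoder and mask decoder remain trainable.
n_trainable = sum(p.numel() for p in peft_sam.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in peft_sam.parameters())
print(f"Trainable parameters: {n_trainable} / {n_total}")

# `peft_sam` keeps SAM's interface: peft_sam(batched_input, multimask_output) forwards to the wrapped model.
```
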