micro_sam.models.sam_3d_wrapper
Module source:

from typing import Any, List, Dict, Type

import torch
import torch.nn as nn

from segment_anything.modeling import Sam
from segment_anything.modeling.image_encoder import window_partition, window_unpartition

from ..util import get_sam_model
from .peft_sam import LoRASurgery


def get_sam_3d_model(
    device,
    n_classes,
    image_size,
    lora_rank=None,
    freeze_encoder=False,
    model_type="vit_b",
    checkpoint_path=None,
):
    if lora_rank is None:
        peft_kwargs = {}
    else:
        peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery}

    _, sam = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        flexible_load_checkpoint=True,
        num_multimask_outputs=n_classes,
        image_size=image_size,
        peft_kwargs=peft_kwargs,
    )

    # Make sure not to freeze the encoder when using LoRA.
    freeze_encoder_ = freeze_encoder if lora_rank is None else False
    sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_)
    sam_3d.to(device)
    return sam_3d


class Sam3DWrapper(nn.Module):
    def __init__(self, sam_model: Sam, freeze_encoder: bool):
        """Initializes the Sam3DWrapper object.

        Args:
            sam_model: The Sam model to be wrapped.
            freeze_encoder: Whether to freeze the image encoder.
        """
        super().__init__()
        sam_model.image_encoder = ImageEncoderViT3DWrapper(
            image_encoder=sam_model.image_encoder
        )
        self.sam_model = sam_model

        self.freeze_encoder = freeze_encoder
        if self.freeze_encoder:
            for param in self.sam_model.image_encoder.parameters():
                param.requires_grad = False

    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """Predict 3D masks for the current inputs.

        Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

        Args:
            batched_input: A list over input images, each a dictionary with the following keys.
                'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
                'original_size': The original size of the image (HxW) before transformation.
            multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

        Returns:
            A list over input images, where each element is a dictionary with the following keys:
                'masks': Mask prediction for this object.
                'iou_predictions': IOU score prediction for this object.
                'low_res_logits': Low resolution mask prediction for this object.
        """
        batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0)
        original_size = batched_input[0]["original_size"]
        assert all(inp["original_size"] == original_size for inp in batched_input)

        # dimensions: [b, 3, d, h, w]
        shape = batched_images.shape
        assert shape[1] == 3
        batch_size, d_size, hw_size = shape[0], shape[2], shape[-2]
        # Transpose the axes, so that the depth axis is the first axis and the channel
        # axis is the second axis. This is expected by the transformer!
        batched_images = batched_images.transpose(1, 2)
        assert batched_images.shape[1] == d_size
        batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size)

        input_images = self.sam_model.preprocess(batched_images)
        image_embeddings = self.sam_model.image_encoder(input_images, d_size)
        sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
            points=None, boxes=None, masks=None
        )
        low_res_masks, iou_predictions = self.sam_model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output
        )
        masks = self.sam_model.postprocess_masks(
            low_res_masks,
            input_size=batched_images.shape[-2:],
            original_size=original_size,
        )

        # Bring the masks and low-res masks into the correct shape:
        # - disentangle batches and z-slices
        # - rearrange output channels and z-slices

        n_channels = masks.shape[1]
        masks = masks.view(*(batch_size, d_size, n_channels, masks.shape[-2], masks.shape[-1]))
        low_res_masks = low_res_masks.view(
            *(batch_size, d_size, n_channels, low_res_masks.shape[-2], low_res_masks.shape[-1])
        )

        masks = masks.transpose(1, 2)
        low_res_masks = low_res_masks.transpose(1, 2)

        # Make the output compatible with the SAM output.
        outputs = [{
            "masks": mask.unsqueeze(0),
            "iou_predictions": iou_pred,
            "low_res_logits": low_res_mask.unsqueeze(0)
        } for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)]
        return outputs


class ImageEncoderViT3DWrapper(nn.Module):
    def __init__(
        self,
        image_encoder: nn.Module,
        num_heads: int = 12,
        embed_dim: int = 768,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.img_size = self.image_encoder.img_size

        # replace default blocks with 3d adapter blocks
        for i, blk in enumerate(self.image_encoder.blocks):
            self.image_encoder.blocks[i] = NDBlockWrapper(block=blk, num_heads=num_heads, dim=embed_dim)

    def forward(self, x: torch.Tensor, d_size: int) -> torch.Tensor:
        x = self.image_encoder.patch_embed(x)
        if self.image_encoder.pos_embed is not None:
            x = x + self.image_encoder.pos_embed

        for blk in self.image_encoder.blocks:
            x = blk(x, d_size)

        x = self.image_encoder.neck(x.permute(0, 3, 1, 2))

        return x


class NDBlockWrapper(nn.Module):
    def __init__(
        self,
        block: nn.Module,
        dim: int,
        num_heads: int,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        adapter_channels: int = 384,
    ):
        super().__init__()
        self.block = block

        self.adapter_channels = adapter_channels
        self.adapter_linear_down = nn.Linear(dim, self.adapter_channels, bias=False)
        self.adapter_linear_up = nn.Linear(self.adapter_channels, dim, bias=False)
        self.adapter_conv = nn.Conv3d(
            self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same"
        )
        self.adapter_act = nn.GELU()
        self.adapter_norm = norm_layer(dim)

        self.adapter_linear_down_2 = nn.Linear(dim, self.adapter_channels, bias=False)
        self.adapter_linear_up_2 = nn.Linear(self.adapter_channels, dim, bias=False)
        self.adapter_conv_2 = nn.Conv3d(
            self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same"
        )
        self.adapter_act_2 = nn.GELU()
        self.adapter_norm_2 = norm_layer(dim)

    def forward(self, x: torch.Tensor, d_size) -> torch.Tensor:
        b_size, hw_size = x.shape[0], x.shape[1]

        # 3D adapter
        shortcut = x
        x = self.adapter_norm(x)
        x = self.adapter_linear_down(x)
        x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels)
        x = torch.permute(x, (0, -1, 1, 2, 3))
        x = self.adapter_conv(x)
        x = torch.permute(x, (0, 2, 3, 4, 1))
        x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels)
        x = self.adapter_act(x)
        x = self.adapter_linear_up(x)
        x = shortcut + x
        # end 3D adapter

        shortcut = x
        x = self.block.norm1(x)
        # Window partition
        if self.block.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.block.window_size)

        x = self.block.attn(x)
        # Reverse window partition
        if self.block.window_size > 0:
            x = window_unpartition(x, self.block.window_size, pad_hw, (H, W))

        x = shortcut + x

        # 3D adapter
        shortcut = x
        x = self.adapter_norm_2(x)
        x = self.adapter_linear_down_2(x)
        x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels)
        x = torch.permute(x, (0, -1, 1, 2, 3))
        x = self.adapter_conv_2(x)
        x = torch.permute(x, (0, 2, 3, 4, 1))
        x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels)
        x = self.adapter_act_2(x)
        x = self.adapter_linear_up_2(x)
        x = shortcut + x
        # end 3D adapter

        x = x + self.block.mlp(self.block.norm2(x))

        return x
def get_sam_3d_model(device, n_classes, image_size, lora_rank=None, freeze_encoder=False, model_type="vit_b", checkpoint_path=None)
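A usage sketch (the device choice and the values for n_classes, image_size and lora_rank are illustrative assumptions, not defaults of the function):

import torch
from micro_sam.models.sam_3d_wrapper import get_sam_3d_model

device = "cuda" if torch.cuda.is_available() else "cpu"
sam_3d = get_sam_3d_model(
    device=device,
    n_classes=3,     # forwarded to the mask decoder as num_multimask_outputs
    image_size=512,  # in-plane size the model is configured for
    lora_rank=4,     # apply LoRA to the image encoder; with None, freeze_encoder is respected
)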
class Sam3DWrapper(nn.Module)
def __init__(self, sam_model: Sam, freeze_encoder: bool)
Initializes the Sam3DWrapper object.
Arguments:
- sam_model: The Sam model to be wrapped.
- freeze_encoder: Whether to freeze the image encoder.
def forward(self, batched_input: List[Dict[str, Any]], multimask_output: bool) -> List[Dict[str, torch.Tensor]]
Predict 3D masks for the current inputs.
Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.
Arguments:
- batched_input: A list over input images, each a dictionary with the following keys. 'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model. 'original_size': The original size of the image (HxW) before transformation.
- multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.
Returns:
A list over input images, where each element is a dictionary with the following keys: 'masks': Mask prediction for this object. 'iou_predictions': IOU score prediction for this object. 'low_res_logits': Low resolution mask prediction for this object.
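A calling sketch for the wrapper (the volume size is an illustrative assumption and sam_3d is the model from the sketch above; the input is expected to already be transformed to the model's image size):

import torch

d, h, w = 8, 512, 512                            # depth and in-plane size after transformation
volume = torch.rand(3, d, h, w)                  # 'image' in 3xDxHxW layout (values are placeholders)
batched_input = [{"image": volume, "original_size": (h, w)}]

with torch.no_grad():
    outputs = sam_3d(batched_input, multimask_output=True)

masks = outputs[0]["masks"]                      # (1, n_classes, d, h, w) mask logits
iou_predictions = outputs[0]["iou_predictions"]  # predicted IOU scores
low_res_logits = outputs[0]["low_res_logits"]    # low resolution mask logits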
class ImageEncoderViT3DWrapper(nn.Module)
def __init__(self, image_encoder: nn.Module, num_heads: int = 12, embed_dim: int = 768)
def forward(self, x: torch.Tensor, d_size: int) -> torch.Tensor
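A shape sketch of the encoder contract (batch size, depth and the 512 input size are assumptions; the 16x down-sampling and 256 output channels correspond to the vit_b defaults):

import torch

b, d, size = 1, 8, 512
slices = torch.rand(b * d, 3, size, size)                      # z-slices stacked along the batch axis, normally produced by Sam.preprocess
embeddings = sam_3d.sam_model.image_encoder(slices, d_size=d)  # sam_3d from the sketch above
print(embeddings.shape)                                        # (b * d, 256, size // 16, size // 16)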
class NDBlockWrapper(nn.Module)
def __init__(self, block: nn.Module, dim: int, num_heads: int, norm_layer: Type[nn.Module] = nn.LayerNorm, adapter_channels: int = 384)
def forward(self, x: torch.Tensor, d_size) -> torch.Tensor
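To make the adapter's reshaping concrete, here is a stand-alone shape trace of its core (the sizes and the freshly constructed layers are assumptions for illustration; the real block uses its own adapter_linear_down, adapter_conv, adapter_act and adapter_linear_up):

import torch
import torch.nn as nn

b, d, hw, dim, c = 1, 8, 32, 768, 384                   # batch, depth, token grid, ViT dim, adapter channels
down = nn.Linear(dim, c, bias=False)
conv = nn.Conv3d(c, c, kernel_size=(3, 1, 1), padding="same")
up = nn.Linear(c, dim, bias=False)

x = torch.rand(b * d, hw, hw, dim)                      # per-slice tokens, depth folded into the batch axis
x = down(x)                                             # (b*d, hw, hw, c)
x = x.view(b, d, hw, hw, c).permute(0, 4, 1, 2, 3)      # (b, c, d, hw, hw), ready for the Conv3d
x = conv(x)                                             # kernel (3, 1, 1) mixes neighbouring z-slices only
x = x.permute(0, 2, 3, 4, 1).reshape(b * d, hw, hw, c)  # back to per-slice tokens
x = up(nn.GELU()(x))                                    # (b*d, hw, hw, dim), added to the shortcut in the block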