micro_sam.models.sam_3d_wrapper
import os
from typing import Any, List, Dict, Type, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam
from segment_anything.modeling.image_encoder import window_partition, window_unpartition

from ..util import get_sam_model
from .peft_sam import LoRASurgery


def get_sam_3d_model(
    device: Union[str, torch.device],
    n_classes: int,
    image_size: int,
    lora_rank: Optional[int] = None,
    freeze_encoder: bool = False,
    model_type: str = "vit_b",
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
):
    if lora_rank is None:
        peft_kwargs = {}
    else:
        peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery}

    _, sam = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        flexible_load_checkpoint=True,
        num_multimask_outputs=n_classes,
        image_size=image_size,
        peft_kwargs=peft_kwargs,
    )

    # Make sure not to freeze the encoder when using LoRA.
    _freeze_encoder = freeze_encoder if lora_rank is None else False
    sam_3d = Sam3DWrapper(sam, freeze_encoder=_freeze_encoder, model_type=model_type)
    sam_3d.to(device)

    return sam_3d


class Sam3DWrapper(nn.Module):
    def __init__(self, sam_model: Sam, freeze_encoder: bool, model_type: str = "vit_b"):
        """Initializes the Sam3DWrapper object.

        Args:
            sam_model: The Sam model to be wrapped.
            freeze_encoder: Whether to freeze the image encoder.
            model_type: The choice of segment anything model, used to configure the adapters
                for the respective model architecture.
        """
        super().__init__()

        # Model configurations
        if model_type == "vit_b":
            embed_dim, num_heads = 768, 12
        elif model_type == "vit_l":
            embed_dim, num_heads = 1024, 16
        elif model_type == "vit_h":
            embed_dim, num_heads = 1280, 16
        else:
            raise ValueError(f"'{model_type}' is not a supported choice of model.")

        sam_model.image_encoder = ImageEncoderViT3DWrapper(
            image_encoder=sam_model.image_encoder, num_heads=num_heads, embed_dim=embed_dim,
        )
        self.sam_model = sam_model

        self.freeze_encoder = freeze_encoder
        if self.freeze_encoder:
            for param in self.sam_model.image_encoder.parameters():
                param.requires_grad = False

    def forward(self, batched_input: List[Dict[str, Any]], multimask_output: bool) -> List[Dict[str, torch.Tensor]]:
        """Predict 3D masks for the current inputs.

        Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

        Args:
            batched_input: A list over input images, each a dictionary with the following keys.
                'image': The image as a torch tensor in 3xDxHxW format, already transformed for input to the model.
                'original_size': The original size of the image (HxW) before transformation.
            multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

        Returns:
            A list over input images, where each element is a dictionary with the following keys:
                'masks': Mask prediction for this object.
                'iou_predictions': IOU score prediction for this object.
                'low_res_logits': Low resolution mask prediction for this object.
        """
        batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0)
        original_size = batched_input[0]["original_size"]
        assert all(inp["original_size"] == original_size for inp in batched_input)

        # dimensions: [b, 3, d, h, w]
        shape = batched_images.shape
        assert shape[1] == 3
        batch_size, d_size, hw_size = shape[0], shape[2], shape[-2]
        # Transpose the axes, so that the depth axis is the first axis and the channel
        # axis is the second axis. This is expected by the transformer!
        batched_images = batched_images.transpose(1, 2)
        assert batched_images.shape[1] == d_size
        batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size)

        input_images = self.sam_model.preprocess(batched_images)
        image_embeddings = self.sam_model.image_encoder(input_images, d_size)
        sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
            points=None, boxes=None, masks=None
        )
        low_res_masks, iou_predictions = self.sam_model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output
        )
        masks = self.sam_model.postprocess_masks(
            low_res_masks,
            input_size=batched_images.shape[-2:],
            original_size=original_size,
        )

        # Bring the masks and low-res masks into the correct shape:
        # - disentangle batches and z-slices
        # - rearrange output channels and z-slices

        n_channels = masks.shape[1]
        masks = masks.view(*(batch_size, d_size, n_channels, masks.shape[-2], masks.shape[-1]))
        low_res_masks = low_res_masks.view(
            *(batch_size, d_size, n_channels, low_res_masks.shape[-2], low_res_masks.shape[-1])
        )

        masks = masks.transpose(1, 2)
        low_res_masks = low_res_masks.transpose(1, 2)

        # Make the output compatible with the SAM output.
        outputs = [{
            "masks": mask.unsqueeze(0),
            "iou_predictions": iou_pred,
            "low_res_logits": low_res_mask.unsqueeze(0)
        } for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)]

        return outputs


class ImageEncoderViT3DWrapper(nn.Module):
    def __init__(self, image_encoder: nn.Module, num_heads: int = 12, embed_dim: int = 768):

        super().__init__()
        self.image_encoder = image_encoder
        self.img_size = self.image_encoder.img_size

        # Replace default blocks with 3d adapter blocks
        for i, blk in enumerate(self.image_encoder.blocks):
            self.image_encoder.blocks[i] = NDBlockWrapper(block=blk, num_heads=num_heads, dim=embed_dim)

    def forward(self, x: torch.Tensor, d_size: int) -> torch.Tensor:
        x = self.image_encoder.patch_embed(x)
        if self.image_encoder.pos_embed is not None:
            x = x + self.image_encoder.pos_embed

        for blk in self.image_encoder.blocks:
            x = blk(x, d_size)

        x = self.image_encoder.neck(x.permute(0, 3, 1, 2))

        return x


class NDBlockWrapper(nn.Module):
    def __init__(
        self,
        block: nn.Module,
        dim: int,
        num_heads: int,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        adapter_channels: int = 384,
    ):
        super().__init__()
        self.block = block

        self.adapter_channels = adapter_channels
        self.adapter_linear_down = nn.Linear(dim, self.adapter_channels, bias=False)
        self.adapter_linear_up = nn.Linear(self.adapter_channels, dim, bias=False)
        self.adapter_conv = nn.Conv3d(
            self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same"
        )
        self.adapter_act = nn.GELU()
        self.adapter_norm = norm_layer(dim)

        self.adapter_linear_down_2 = nn.Linear(dim, self.adapter_channels, bias=False)
        self.adapter_linear_up_2 = nn.Linear(self.adapter_channels, dim, bias=False)
        self.adapter_conv_2 = nn.Conv3d(
            self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same"
        )
        self.adapter_act_2 = nn.GELU()
        self.adapter_norm_2 = norm_layer(dim)

    def forward(self, x: torch.Tensor, d_size) -> torch.Tensor:
        b_size, hw_size = x.shape[0], x.shape[1]

        # 3D adapter
        shortcut = x
        x = self.adapter_norm(x)
        x = self.adapter_linear_down(x)
        x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels)
        x = torch.permute(x, (0, -1, 1, 2, 3))
        x = self.adapter_conv(x)
        x = torch.permute(x, (0, 2, 3, 4, 1))
        x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels)
        x = self.adapter_act(x)
        x = self.adapter_linear_up(x)
        x = shortcut + x
        # end 3D adapter

        shortcut = x
        x = self.block.norm1(x)
        # Window partition
        if self.block.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.block.window_size)

        x = self.block.attn(x)
        # Reverse window partition
        if self.block.window_size > 0:
            x = window_unpartition(x, self.block.window_size, pad_hw, (H, W))

        x = shortcut + x

        # 3D adapter
        shortcut = x
        x = self.adapter_norm_2(x)
        x = self.adapter_linear_down_2(x)
        x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels)
        x = torch.permute(x, (0, -1, 1, 2, 3))
        x = self.adapter_conv_2(x)
        x = torch.permute(x, (0, 2, 3, 4, 1))
        x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels)
        x = self.adapter_act_2(x)
        x = self.adapter_linear_up_2(x)
        x = shortcut + x
        # end 3D adapter

        x = x + self.block.mlp(self.block.norm2(x))

        return x
def get_sam_3d_model(device: Union[str, torch.device], n_classes: int, image_size: int, lora_rank: Optional[int] = None, freeze_encoder: bool = False, model_type: str = "vit_b", checkpoint_path: Optional[Union[str, os.PathLike]] = None)
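A minimal usage sketch for building the 3D model (the device handling, class count, and image size below are hypothetical example values; with checkpoint_path=None the default weights for the chosen model_type are assumed to be fetched):

import torch
from micro_sam.models.sam_3d_wrapper import get_sam_3d_model

device = "cuda" if torch.cuda.is_available() else "cpu"
# Example values: 2 output classes and 512 x 512 input slices; pass lora_rank (e.g. 4)
# to attach LoRA adapters instead of fine-tuning the full encoder.
sam_3d = get_sam_3d_model(
    device=device,
    n_classes=2,
    image_size=512,
    lora_rank=None,
    freeze_encoder=False,
    model_type="vit_b",
)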
class Sam3DWrapper(nn.Module)
def __init__(self, sam_model: Sam, freeze_encoder: bool, model_type: str = "vit_b")
Initializes the Sam3DWrapper object.
Arguments:
- sam_model: The Sam model to be wrapped.
- freeze_encoder: Whether to freeze the image encoder.
- model_type: The choice of segment anything model, used to configure the adapters for the respective model architecture.
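For reference, a hedged sketch of wrapping a Sam instance directly instead of going through get_sam_3d_model (device, image size, and class count are example values; get_sam_3d_model above performs these same steps and is the usual entry point):

from micro_sam.util import get_sam_model
from micro_sam.models.sam_3d_wrapper import Sam3DWrapper

# Example values; flexible_load_checkpoint and num_multimask_outputs mirror what
# get_sam_3d_model passes internally.
_, sam = get_sam_model(
    model_type="vit_b", device="cpu", return_sam=True, flexible_load_checkpoint=True,
    num_multimask_outputs=2, image_size=512,
)
sam_3d = Sam3DWrapper(sam, freeze_encoder=True, model_type="vit_b")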
def forward(self, batched_input: List[Dict[str, Any]], multimask_output: bool) -> List[Dict[str, torch.Tensor]]
Predict 3D masks for the current inputs.
Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.
Arguments:
- batched_input: A list over input images, each a dictionary with the following keys. 'image': The image as a torch tensor in 3xDxHxW format, already transformed for input to the model. 'original_size': The original size of the image (HxW) before transformation.
- multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.
Returns:
A list over input images, where each element is a dictionary with the following keys: 'masks': Mask prediction for this object. 'iou_predictions': IOU score prediction for this object. 'low_res_logits': Low resolution mask prediction for this object.
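A hedged sketch of calling the wrapper for inference, continuing the get_sam_3d_model sketch above (the volume shape and slice size are hypothetical example values and assume a model built with image_size=512):

import torch

# One volume with 3 channels, 8 z-slices, and 512 x 512 per slice (example values).
volume = torch.rand(3, 8, 512, 512)
batched_input = [{"image": volume, "original_size": (512, 512)}]

sam_3d.eval()
with torch.no_grad():
    outputs = sam_3d(batched_input, multimask_output=True)

masks = outputs[0]["masks"]                  # roughly [1, n_output_channels, 8, 512, 512]
iou_predictions = outputs[0]["iou_predictions"]
low_res_logits = outputs[0]["low_res_logits"]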
class ImageEncoderViT3DWrapper(nn.Module)
def __init__(self, image_encoder: nn.Module, num_heads: int = 12, embed_dim: int = 768)
def forward(self, x: torch.Tensor, d_size: int) -> torch.Tensor
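For orientation, a hedged trace of the tensor shapes through this forward pass, assuming a ViT-B backbone with patch size 16 and 512 x 512 input slices (example values, not fixed by the wrapper):

# x arrives as [B * D, 3, 512, 512]: the z-slices of each volume folded into the batch axis.
# patch_embed(x)            -> [B * D, 32, 32, 768]   token grid per slice (512 / 16 = 32)
# blk(x, d_size) per block  -> [B * D, 32, 32, 768]   each NDBlockWrapper also mixes along the depth axis
# neck(permuted to NCHW)    -> [B * D, 256, 32, 32]   SAM image embedding consumed by the mask decoder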
class NDBlockWrapper(nn.Module)
def __init__(self, block: nn.Module, dim: int, num_heads: int, norm_layer: Type[nn.Module] = nn.LayerNorm, adapter_channels: int = 384)
def forward(self, x: torch.Tensor, d_size) -> torch.Tensor
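The depth mixing in the two adapter branches is the non-obvious part, so here is a minimal, self-contained sketch of the reshaping around the Conv3d (dim=768 matches ViT-B; the batch, depth, and token-grid sizes are toy example values):

import torch
import torch.nn as nn

b, d, hw, dim, adapter_channels = 1, 4, 8, 768, 384
x = torch.rand(b * d, hw, hw, dim)                 # per-slice ViT tokens, depth folded into the batch axis

norm = nn.LayerNorm(dim)
down = nn.Linear(dim, adapter_channels, bias=False)
conv = nn.Conv3d(adapter_channels, adapter_channels, kernel_size=(3, 1, 1), padding="same")
up = nn.Linear(adapter_channels, dim, bias=False)
act = nn.GELU()

h = down(norm(x))                                  # [b*d, hw, hw, 384]
h = h.view(b, d, hw, hw, adapter_channels)         # separate the batch and depth axes again
h = h.permute(0, 4, 1, 2, 3)                       # [b, 384, d, hw, hw] for Conv3d
h = conv(h)                                        # kernel (3, 1, 1): mixes only along the depth axis
h = h.permute(0, 2, 3, 4, 1).reshape(b * d, hw, hw, adapter_channels)
out = x + up(act(h))                               # residual back onto the ViT tokens
print(out.shape)                                   # torch.Size([4, 8, 8, 768])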