micro_sam.models.sam_3d_wrapper

from typing import Any, List, Dict, Type

import torch
import torch.nn as nn

from segment_anything.modeling import Sam
from segment_anything.modeling.image_encoder import window_partition, window_unpartition

from ..util import get_sam_model
from .peft_sam import LoRASurgery


def get_sam_3d_model(
    device,
    n_classes,
    image_size,
    lora_rank=None,
    freeze_encoder=False,
    model_type="vit_b",
    checkpoint_path=None,
):
    if lora_rank is None:
        peft_kwargs = {}
    else:
        peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery}

    _, sam = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        flexible_load_checkpoint=True,
        num_multimask_outputs=n_classes,
        image_size=image_size,
        peft_kwargs=peft_kwargs,
    )

    # Make sure not to freeze the encoder when using LoRA.
    freeze_encoder_ = freeze_encoder if lora_rank is None else False
    sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_)
    sam_3d.to(device)
    return sam_3d
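# Usage sketch for `get_sam_3d_model` (illustrative, not part of the original module):
# the values below are placeholders. Passing `lora_rank` switches the image encoder to
# LoRA fine-tuning, in which case `freeze_encoder` is ignored (see above).
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   sam_3d = get_sam_3d_model(device, n_classes=3, image_size=512, lora_rank=4)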


class Sam3DWrapper(nn.Module):
    def __init__(self, sam_model: Sam, freeze_encoder: bool):
        """Initializes the Sam3DWrapper object.

        Args:
            sam_model: The Sam model to be wrapped.
            freeze_encoder: Whether to freeze the image encoder.
        """
        super().__init__()
        sam_model.image_encoder = ImageEncoderViT3DWrapper(
            image_encoder=sam_model.image_encoder
        )
        self.sam_model = sam_model

        self.freeze_encoder = freeze_encoder
        if self.freeze_encoder:
            for param in self.sam_model.image_encoder.parameters():
                param.requires_grad = False

    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """Predict 3D masks for the current inputs.

        Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

        Args:
            batched_input: A list over input images, each a dictionary with the following keys:
                'image': The image as a torch tensor in 3xDxHxW format, already transformed for input to the model.
                'original_size': The original size of the image (HxW) before transformation.
            multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

        Returns:
            A list over input images, where each element is a dictionary with the following keys:
                'masks': Mask prediction for this object.
                'iou_predictions': IOU score prediction for this object.
                'low_res_logits': Low resolution mask prediction for this object.
        """
        batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0)
        original_size = batched_input[0]["original_size"]
        assert all(inp["original_size"] == original_size for inp in batched_input)

        # dimensions: [b, 3, d, h, w]
        shape = batched_images.shape
        assert shape[1] == 3
        batch_size, d_size, hw_size = shape[0], shape[2], shape[-2]
        # Transpose the axes, so that the depth axis is the first axis and the channel
        # axis is the second axis. This is expected by the transformer!
        batched_images = batched_images.transpose(1, 2)
        assert batched_images.shape[1] == d_size
        batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size)

        input_images = self.sam_model.preprocess(batched_images)
        image_embeddings = self.sam_model.image_encoder(input_images, d_size)
        sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
            points=None, boxes=None, masks=None
        )
        low_res_masks, iou_predictions = self.sam_model.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output
        )
        masks = self.sam_model.postprocess_masks(
            low_res_masks,
            input_size=batched_images.shape[-2:],
            original_size=original_size,
        )

        # Bring the masks and low-res masks into the correct shape:
        # - disentangle batches and z-slices
        # - rearrange output channels and z-slices

        n_channels = masks.shape[1]
        masks = masks.view(*(batch_size, d_size, n_channels, masks.shape[-2], masks.shape[-1]))
        low_res_masks = low_res_masks.view(
            *(batch_size, d_size, n_channels, low_res_masks.shape[-2], low_res_masks.shape[-1])
        )

        masks = masks.transpose(1, 2)
        low_res_masks = low_res_masks.transpose(1, 2)

        # Make the output compatible with the SAM output.
        outputs = [{
            "masks": mask.unsqueeze(0),
            "iou_predictions": iou_pred,
            "low_res_logits": low_res_mask.unsqueeze(0)
        } for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)]
        return outputs
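# Usage note for `Sam3DWrapper` (illustrative, not part of the original module): each entry
# of `batched_input` carries one 3xDxHxW volume that has already been resized to the model's
# input resolution, and all entries must share the same 'original_size'. For example (sizes
# are placeholders, assuming the model was built with image_size=512):
#
#   volume = torch.rand(3, 8, 512, 512)  # 3 channels, 8 z-slices
#   batched_input = [{"image": volume, "original_size": (512, 512)}]
#   outputs = sam_3d(batched_input, multimask_output=True)
#   masks = outputs[0]["masks"]  # expected layout: (1, n_classes, 8, 512, 512)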


class ImageEncoderViT3DWrapper(nn.Module):
    def __init__(
        self,
        image_encoder: nn.Module,
        num_heads: int = 12,
        embed_dim: int = 768,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.img_size = self.image_encoder.img_size

        # replace default blocks with 3d adapter blocks
        for i, blk in enumerate(self.image_encoder.blocks):
            self.image_encoder.blocks[i] = NDBlockWrapper(block=blk, num_heads=num_heads, dim=embed_dim)

    def forward(self, x: torch.Tensor, d_size: int) -> torch.Tensor:
        x = self.image_encoder.patch_embed(x)
        if self.image_encoder.pos_embed is not None:
            x = x + self.image_encoder.pos_embed

        for blk in self.image_encoder.blocks:
            x = blk(x, d_size)

        x = self.image_encoder.neck(x.permute(0, 3, 1, 2))

        return x
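# Shape note (added for clarity, not part of the original module): the wrapped encoder
# receives the z-slices of a volume as a 2D batch of size B*D. `patch_embed` turns this
# into a token grid of shape (B*D, H/16, W/16, embed_dim); each `NDBlockWrapper` then
# temporarily folds the batch axis back into (B, D) so that its Conv3d adapters can mix
# information along the depth axis before the tokens are flattened again.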


class NDBlockWrapper(nn.Module):
    def __init__(
        self,
        block: nn.Module,
        dim: int,
        num_heads: int,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        adapter_channels: int = 384,
    ):
        super().__init__()
        self.block = block

        self.adapter_channels = adapter_channels
        self.adapter_linear_down = nn.Linear(dim, self.adapter_channels, bias=False)
        self.adapter_linear_up = nn.Linear(self.adapter_channels, dim, bias=False)
        self.adapter_conv = nn.Conv3d(
            self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same"
        )
        self.adapter_act = nn.GELU()
        self.adapter_norm = norm_layer(dim)

        self.adapter_linear_down_2 = nn.Linear(dim, self.adapter_channels, bias=False)
        self.adapter_linear_up_2 = nn.Linear(self.adapter_channels, dim, bias=False)
        self.adapter_conv_2 = nn.Conv3d(
            self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same"
        )
        self.adapter_act_2 = nn.GELU()
        self.adapter_norm_2 = norm_layer(dim)

    def forward(self, x: torch.Tensor, d_size) -> torch.Tensor:
        b_size, hw_size = x.shape[0], x.shape[1]

        # 3D adapter
        shortcut = x
        x = self.adapter_norm(x)
        x = self.adapter_linear_down(x)
        x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels)
        x = torch.permute(x, (0, -1, 1, 2, 3))
        x = self.adapter_conv(x)
        x = torch.permute(x, (0, 2, 3, 4, 1))
        x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels)
        x = self.adapter_act(x)
        x = self.adapter_linear_up(x)
        x = shortcut + x
        # end 3D adapter

        shortcut = x
        x = self.block.norm1(x)
        # Window partition
        if self.block.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.block.window_size)

        x = self.block.attn(x)
        # Reverse window partition
        if self.block.window_size > 0:
            x = window_unpartition(x, self.block.window_size, pad_hw, (H, W))

        x = shortcut + x

        # 3D adapter
        shortcut = x
        x = self.adapter_norm_2(x)
        x = self.adapter_linear_down_2(x)
        x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels)
        x = torch.permute(x, (0, -1, 1, 2, 3))
        x = self.adapter_conv_2(x)
        x = torch.permute(x, (0, 2, 3, 4, 1))
        x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels)
        x = self.adapter_act_2(x)
        x = self.adapter_linear_up_2(x)
        x = shortcut + x
        # end 3D adapter

        x = x + self.block.mlp(self.block.norm2(x))

        return x