micro_sam.models.sam_3d_wrapper

  1import os
  2from typing import Any, List, Dict, Type, Union, Optional
  3
  4import torch
  5import torch.nn as nn
  6
  7from segment_anything.modeling import Sam
  8from segment_anything.modeling.image_encoder import window_partition, window_unpartition
  9
 10from ..util import get_sam_model
 11from .peft_sam import LoRASurgery
 12
 13
 14def get_sam_3d_model(
 15    device: Union[str, torch.device],
 16    n_classes: int,
 17    image_size: int,
 18    lora_rank: Optional[int] = None,
 19    freeze_encoder: bool = False,
 20    model_type: str = "vit_b",
 21    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 22):
 23    if lora_rank is None:
 24        peft_kwargs = {}
 25    else:
 26        peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery}
 27
 28    _, sam = get_sam_model(
 29        model_type=model_type,
 30        device=device,
 31        checkpoint_path=checkpoint_path,
 32        return_sam=True,
 33        flexible_load_checkpoint=True,
 34        num_multimask_outputs=n_classes,
 35        image_size=image_size,
 36        peft_kwargs=peft_kwargs,
 37    )
 38
 39    # Make sure not to freeze the encoder when using LoRA.
 40    _freeze_encoder = freeze_encoder if lora_rank is None else False
 41    sam_3d = Sam3DWrapper(sam, freeze_encoder=_freeze_encoder, model_type=model_type)
 42    sam_3d.to(device)
 43
 44    return sam_3d
 45
 46
 47class Sam3DWrapper(nn.Module):
 48    def __init__(self, sam_model: Sam, freeze_encoder: bool, model_type: str = "vit_b"):
 49        """Initializes the Sam3DWrapper object.
 50
 51        Args:
 52            sam_model: The Sam model to be wrapped.
 53            freeze_encoder: Whether to freeze the image encoder.
 54            model_type: The Segment Anything model type; it determines the embedding dimension and number of attention heads used for the adapter configuration.
 55        """
 56        super().__init__()
 57
 58        # Model configurations
 59        if model_type == "vit_b":
 60            embed_dim, num_heads = 768, 12
 61        elif model_type == "vit_l":
 62            embed_dim, num_heads = 1024, 16
 63        elif model_type == "vit_h":
 64            embed_dim, num_heads = 1280, 16
 65        else:
 66            raise ValueError(f"'{model_type}' is not a supported choice of model.")
 67
 68        sam_model.image_encoder = ImageEncoderViT3DWrapper(
 69            image_encoder=sam_model.image_encoder, num_heads=num_heads, embed_dim=embed_dim,
 70        )
 71        self.sam_model = sam_model
 72
 73        self.freeze_encoder = freeze_encoder
 74        if self.freeze_encoder:
 75            for param in self.sam_model.image_encoder.parameters():
 76                param.requires_grad = False
 77
 78    def forward(self, batched_input: List[Dict[str, Any]], multimask_output: bool) -> List[Dict[str, torch.Tensor]]:
 79        """Predict 3D masks for the current inputs.
 80
 81        Unlike original SAM this model only supports automatic segmentation and does not support prompts.
 82
 83        Args:
 84            batched_input: A list over input images, each a dictionary with the following keys.
 85                'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
 86                'original_size': The original size of the image (HxW) before transformation.
 87            multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.
 88
 89        Returns:
 90        A list over input images, where each element is a dictionary with the following keys:
 91                'masks': Mask prediction for this object.
 92                'iou_predictions': IOU score prediction for this object.
 93                'low_res_logits': Low resolution mask prediction for this object.
 94        """
 95        batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0)
 96        original_size = batched_input[0]["original_size"]
 97        assert all(inp["original_size"] == original_size for inp in batched_input)
 98
 99        # dimensions: [b, 3, d, h, w]
100        shape = batched_images.shape
101        assert shape[1] == 3
102        batch_size, d_size, hw_size = shape[0], shape[2], shape[-2]
103        # Transpose the axes, so that the depth axis is the first axis and the channel
104        # axis is the second axis. This is expected by the transformer!
105        batched_images = batched_images.transpose(1, 2)
106        assert batched_images.shape[1] == d_size
107        batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size)
108
109        input_images = self.sam_model.preprocess(batched_images)
110        image_embeddings = self.sam_model.image_encoder(input_images, d_size)
111        sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
112            points=None, boxes=None, masks=None
113        )
114        low_res_masks, iou_predictions = self.sam_model.mask_decoder(
115            image_embeddings=image_embeddings,
116            image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
117            sparse_prompt_embeddings=sparse_embeddings,
118            dense_prompt_embeddings=dense_embeddings,
119            multimask_output=multimask_output
120        )
121        masks = self.sam_model.postprocess_masks(
122            low_res_masks,
123            input_size=batched_images.shape[-2:],
124            original_size=original_size,
125        )
126
127        # Bring the masks and low-res masks into the correct shape:
128        # - disentangle batches and z-slices
129        # - rearrange output channels and z-slices
130
131        n_channels = masks.shape[1]
132        masks = masks.view(*(batch_size, d_size, n_channels, masks.shape[-2], masks.shape[-1]))
133        low_res_masks = low_res_masks.view(
134            *(batch_size, d_size, n_channels, low_res_masks.shape[-2], low_res_masks.shape[-1])
135        )
136
137        masks = masks.transpose(1, 2)
138        low_res_masks = low_res_masks.transpose(1, 2)
139
140        # Make the output compatible with the SAM output.
141        outputs = [{
142            "masks": mask.unsqueeze(0),
143            "iou_predictions": iou_pred,
144            "low_res_logits": low_res_mask.unsqueeze(0)
145        } for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)]
146
147        return outputs
148
149
150class ImageEncoderViT3DWrapper(nn.Module):
151    def __init__(self, image_encoder: nn.Module, num_heads: int = 12, embed_dim: int = 768):
152
153        super().__init__()
154        self.image_encoder = image_encoder
155        self.img_size = self.image_encoder.img_size
156
157        # Replace default blocks with 3d adapter blocks
158        for i, blk in enumerate(self.image_encoder.blocks):
159            self.image_encoder.blocks[i] = NDBlockWrapper(block=blk, num_heads=num_heads, dim=embed_dim)
160
161    def forward(self, x: torch.Tensor, d_size: int) -> torch.Tensor:
162        x = self.image_encoder.patch_embed(x)
163        if self.image_encoder.pos_embed is not None:
164            x = x + self.image_encoder.pos_embed
165
166        for blk in self.image_encoder.blocks:
167            x = blk(x, d_size)
168
169        x = self.image_encoder.neck(x.permute(0, 3, 1, 2))
170
171        return x
172
173
174class NDBlockWrapper(nn.Module):
175    def __init__(
176        self,
177        block: nn.Module,
178        dim: int,
179        num_heads: int,
180        norm_layer: Type[nn.Module] = nn.LayerNorm,
181        adapter_channels: int = 384,
182    ):
183        super().__init__()
184        self.block = block
185
186        self.adapter_channels = adapter_channels
187        self.adapter_linear_down = nn.Linear(dim, self.adapter_channels, bias=False)
188        self.adapter_linear_up = nn.Linear(self.adapter_channels, dim, bias=False)
189        self.adapter_conv = nn.Conv3d(
190            self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same"
191        )
192        self.adapter_act = nn.GELU()
193        self.adapter_norm = norm_layer(dim)
194
195        self.adapter_linear_down_2 = nn.Linear(dim, self.adapter_channels, bias=False)
196        self.adapter_linear_up_2 = nn.Linear(self.adapter_channels, dim, bias=False)
197        self.adapter_conv_2 = nn.Conv3d(
198            self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same"
199        )
200        self.adapter_act_2 = nn.GELU()
201        self.adapter_norm_2 = norm_layer(dim)
202
203    def forward(self, x: torch.Tensor, d_size) -> torch.Tensor:
204        b_size, hw_size = x.shape[0], x.shape[1]
205
206        # 3D adapter
207        shortcut = x
208        x = self.adapter_norm(x)
209        x = self.adapter_linear_down(x)
210        x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels)
211        x = torch.permute(x, (0, -1, 1, 2, 3))
212        x = self.adapter_conv(x)
213        x = torch.permute(x, (0, 2, 3, 4, 1))
214        x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels)
215        x = self.adapter_act(x)
216        x = self.adapter_linear_up(x)
217        x = shortcut + x
218        # end 3D adapter
219
220        shortcut = x
221        x = self.block.norm1(x)
222        # Window partition
223        if self.block.window_size > 0:
224            H, W = x.shape[1], x.shape[2]
225            x, pad_hw = window_partition(x, self.block.window_size)
226
227        x = self.block.attn(x)
228        # Reverse window partition
229        if self.block.window_size > 0:
230            x = window_unpartition(x, self.block.window_size, pad_hw, (H, W))
231
232        x = shortcut + x
233
234        # 3D adapter
235        shortcut = x
236        x = self.adapter_norm_2(x)
237        x = self.adapter_linear_down_2(x)
238        x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels)
239        x = torch.permute(x, (0, -1, 1, 2, 3))
240        x = self.adapter_conv_2(x)
241        x = torch.permute(x, (0, 2, 3, 4, 1))
242        x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels)
243        x = self.adapter_act_2(x)
244        x = self.adapter_linear_up_2(x)
245        x = shortcut + x
246        # end 3D adapter
247
248        x = x + self.block.mlp(self.block.norm2(x))
249
250        return x
def get_sam_3d_model( device: Union[str, torch.device], n_classes: int, image_size: int, lora_rank: Optional[int] = None, freeze_encoder: bool = False, model_type: str = 'vit_b', checkpoint_path: Union[os.PathLike, str, NoneType] = None):
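
A minimal usage sketch for building the 3D model; the concrete values for n_classes, image_size, and lora_rank below are illustrative choices rather than prescribed defaults.

import torch

from micro_sam.models.sam_3d_wrapper import get_sam_3d_model

device = "cuda" if torch.cuda.is_available() else "cpu"

sam_3d = get_sam_3d_model(
    device=device,
    n_classes=3,       # number of output channels of the mask decoder
    image_size=512,    # spatial size the model is configured for
    lora_rank=4,       # attach LoRA to the image encoder; None disables PEFT
    model_type="vit_b",
)

Note that freeze_encoder is overridden to False whenever lora_rank is given, so the encoder stays trainable for the LoRA updates.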
class Sam3DWrapper(torch.nn.modules.module.Module):

Sam3DWrapper( sam_model: segment_anything.modeling.sam.Sam, freeze_encoder: bool, model_type: str = 'vit_b')

Initializes the Sam3DWrapper object.

Arguments:
  • sam_model: The Sam model to be wrapped.
  • freeze_encoder: Whether to freeze the image encoder.
  • model_type: The Segment Anything model type; it determines the embedding dimension and number of attention heads used for the adapter configuration.
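
For reference, a hedged sketch of wrapping an already loaded SAM model directly; it assumes that micro_sam.util.get_sam_model can be called with only the arguments shown (everything else left at its defaults). get_sam_3d_model above remains the more convenient entry point.

from micro_sam.util import get_sam_model
from micro_sam.models.sam_3d_wrapper import Sam3DWrapper

# With return_sam=True the underlying Sam object is returned next to the predictor.
_, sam = get_sam_model(model_type="vit_b", return_sam=True)

# Replace the image encoder with its 3D-adapter version and keep it trainable.
sam_3d = Sam3DWrapper(sam, freeze_encoder=False, model_type="vit_b")

Setting freeze_encoder=True would freeze all image encoder parameters, including the newly added 3D adapter layers, since they live inside the wrapped encoder.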
def forward( self, batched_input: List[Dict[str, Any]], multimask_output: bool) -> List[Dict[str, torch.Tensor]]:

Predict 3D masks for the current inputs.

Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

Arguments:
  • batched_input: A list over input images, each a dictionary with the following keys: 'image': The image as a torch tensor in 3xDxHxW format, already transformed for input to the model. 'original_size': The original size of the image (HxW) before transformation.
  • multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.
Returns:

A list over input images, where each element is a dictionary with the following keys: 'masks': Mask prediction for this object. 'iou_predictions': IOU score prediction for this object. 'low_res_logits': Low resolution mask prediction for this object.
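
A hedged sketch of a forward pass, assuming sam_3d was created as in the get_sam_3d_model example above and configured for 512 x 512 inputs; the volume contents and the depth of 8 slices are placeholders.

import torch

volume = torch.rand(3, 8, 512, 512)     # 3 x D x H x W, already transformed for the model

batched_input = [{
    "image": volume,
    "original_size": (512, 512),        # H x W before the transformation
}]

sam_3d.eval()
with torch.no_grad():
    outputs = sam_3d(batched_input, multimask_output=True)

masks = outputs[0]["masks"]                    # 1 x C x D x H x W logits, with C the number of predicted channels
low_res_logits = outputs[0]["low_res_logits"]  # lower-resolution logits in the same layout
iou_predictions = outputs[0]["iou_predictions"]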

class ImageEncoderViT3DWrapper(torch.nn.modules.module.Module):

ImageEncoderViT3DWrapper( image_encoder: torch.nn.modules.module.Module, num_heads: int = 12, embed_dim: int = 768)
def forward(self, x: torch.Tensor, d_size: int) -> torch.Tensor:
class NDBlockWrapper(torch.nn.modules.module.Module):
NDBlockWrapper( block: torch.nn.modules.module.Module, dim: int, num_heads: int, norm_layer: Type[torch.nn.modules.module.Module] = <class 'torch.nn.modules.normalization.LayerNorm'>, adapter_channels: int = 384)
def forward(self, x: torch.Tensor, d_size) -> torch.Tensor:
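
The adapters in NDBlockWrapper mix information across slices by regrouping the flattened per-slice token grids into a 5D volume and applying a depth-only convolution (kernel size 3 x 1 x 1). The standalone sketch below mirrors that view/permute round trip with illustrative sizes; all tensor names are hypothetical.

import torch
import torch.nn as nn

# Illustrative sizes: 2 volumes, 4 slices each, a 14 x 14 token grid, 384 adapter channels.
batch, depth, grid, channels = 2, 4, 14, 384
adapter_conv = nn.Conv3d(channels, channels, kernel_size=(3, 1, 1), padding="same")

# The ViT blocks see all slices of all volumes flattened into one batch axis: (B*D, H, W, C).
tokens = torch.rand(batch * depth, grid, grid, channels)

# Regroup into (B, C, D, H, W) so the convolution runs along the depth axis only.
x = tokens.view(batch, depth, grid, grid, channels)
x = x.permute(0, 4, 1, 2, 3)
x = adapter_conv(x)

# Undo the regrouping to return to the per-slice token layout expected by the block.
x = x.permute(0, 2, 3, 4, 1).contiguous()
x = x.view(batch * depth, grid, grid, channels)

assert x.shape == tokens.shape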