micro_sam.models.simple_sam_3d_wrapper

import os
from contextlib import nullcontext
from typing import Any, List, Dict, Union, Optional

import torch
import torch.nn as nn

from ..util import get_sam_model
from .peft_sam import LoRASurgery


def get_simple_sam_3d_model(
    device: Union[str, torch.device],
    n_classes: int,
    image_size: int,
    lora_rank: Optional[int] = None,
    freeze_encoder: bool = False,
    model_type: str = "vit_b",
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
):
    """Get a simple 3D wrapper around the Segment Anything Model (SAM).

    Args:
        device: The device for the model.
        n_classes: The number of output classes for semantic segmentation.
        image_size: The size to which the inputs for the image encoder are resized.
        lora_rank: The rank for LoRA-based parameter-efficient finetuning of the
            image encoder. By default (None) LoRA is not used.
        freeze_encoder: Whether to freeze the image encoder. Ignored when LoRA is used.
        model_type: The Segment Anything model type, e.g. "vit_b".
        checkpoint_path: Optional path to a checkpoint for loading the SAM weights.

    Returns:
        The SimpleSam3DWrapper model, moved to the given device.
    """
    if lora_rank is None:
        peft_kwargs = {}
    else:
        peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery}

    _, sam = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        image_size=image_size,
        flexible_load_checkpoint=True,
        peft_kwargs=peft_kwargs,
    )

    # Make sure not to freeze the encoder when using LoRA.
    freeze_encoder_ = freeze_encoder if lora_rank is None else False
    sam_3d = SimpleSam3DWrapper(sam, num_classes=n_classes, freeze_encoder=freeze_encoder_)
    sam_3d.to(device)

    return sam_3d


class BasicBlock(nn.Module):
    """Residual 3D convolution block, followed by 2x upsampling of the in-plane (H, W) dimensions."""
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True,
        mode="nearest"
    ):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        # 1x1x1 convolution to match the channels of the residual connection.
        self.downsample = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        self.leakyrelu = nn.LeakyReLU()

        # Upsample only the in-plane dimensions; the depth stays unchanged.
        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode=mode)

    def forward(self, x):
        residual = self.downsample(x)

        out = self.conv1(x)
        out = self.conv2(out)
        out += residual

        out = self.leakyrelu(out)
        out = self.up(out)
        return out


class SegmentationHead(nn.Sequential):
    """Prediction head that halves the feature channels and maps them to per-class logits with a final 1x1x1 convolution."""
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True
    ):
        super().__init__()

        self.conv_pred = nn.Sequential(
            nn.Conv3d(
                in_channels, in_channels // 2, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            ),
            nn.InstanceNorm3d(in_channels // 2),
            nn.LeakyReLU()
        )
        self.segmentation_head = nn.Conv3d(in_channels // 2, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv_pred(x)
        return self.segmentation_head(x)


class SimpleSam3DWrapper(nn.Module):
    """Wrapper that turns SAM into a model for 3D semantic segmentation.

    The SAM image encoder is applied to each slice of the input volume and the
    stacked per-slice features are decoded with 3D convolutions.
    """
    def __init__(self, sam, num_classes, freeze_encoder):
        super().__init__()

        self.sam = sam
        self.freeze_encoder = freeze_encoder
        if self.freeze_encoder:
            for param in self.sam.image_encoder.parameters():
                param.requires_grad = False
            self.no_grad = torch.no_grad

        else:
            self.no_grad = nullcontext

        self.decoders = nn.ModuleList([
            BasicBlock(in_channels=256, out_channels=128),
            BasicBlock(in_channels=128, out_channels=64),
            BasicBlock(in_channels=64, out_channels=32),
            BasicBlock(in_channels=32, out_channels=16),
        ])
        self.out_conv = SegmentationHead(in_channels=16, out_channels=num_classes)

    def _apply_image_encoder(self, x, D):
        # Encode each z-slice with the SAM image encoder and stack the
        # per-slice features along a new depth dimension.
        encoder_features = []
        for d in range(D):
            image = x[:, :, d]
            feature = self.sam.image_encoder(image)
            encoder_features.append(feature)
        encoder_features = torch.stack(encoder_features, 2)
        return encoder_features

    def forward(
        self, batched_input: List[Dict[str, Any]], multimask_output: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """Predict 3D masks for the current inputs.

        Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

        Args:
            batched_input: A list over input images, each a dictionary with the following keys:
                'image': The image as a torch tensor in 3xDxHxW format. Already transformed for input to the model.
            multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

        Returns:
            A list over input images, where each element is a dictionary with the following keys:
                'masks': Mask prediction for this object.
        """
        x = torch.stack([inp["image"] for inp in batched_input], dim=0)

        B, C, D, H, W = x.shape
        assert C == 3

        with self.no_grad():
            features = self._apply_image_encoder(x, D)

        out = features
        for decoder in self.decoders:
            out = decoder(out)
        logits = self.out_conv(out)

        outputs = [{"masks": mask.unsqueeze(0)} for mask in logits]
        return outputs
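
Example usage. A minimal sketch, not part of the module source: the parameter values (n_classes=2, image_size=512, lora_rank=4) and the random input volume are illustrative, and the "vit_b" weights are assumed to be available via get_sam_model (fetched automatically when no checkpoint_path is given).

import torch

from micro_sam.models.simple_sam_3d_wrapper import get_simple_sam_3d_model

# Build the 3D wrapper around a SAM ViT-B encoder with LoRA rank 4.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_simple_sam_3d_model(
    device=device,
    n_classes=2,     # e.g. foreground and background (illustrative)
    image_size=512,  # in-plane size expected by the image encoder
    lora_rank=4,     # or None to train (or freeze) the full encoder
)

# Each input is a dictionary with an 'image' tensor in 3xDxHxW format.
volume = torch.rand(3, 4, 512, 512, device=device)
batched_input = [{"image": volume}]

model.eval()
with torch.no_grad():
    outputs = model(batched_input, multimask_output=False)

# One output dictionary per input image. With the ViT patch size of 16, the
# four 2x upsampling blocks restore the in-plane resolution, so 'masks' has
# shape (1, n_classes, D, H, W).
print(outputs[0]["masks"].shape)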