micro_sam.models.simple_sam_3d_wrapper

from contextlib import nullcontext
from typing import Any, List, Dict

import torch
import torch.nn as nn

from ..util import get_sam_model
from .peft_sam import LoRASurgery


def get_simple_sam_3d_model(
    device,
    n_classes,
    image_size,
    lora_rank=None,
    freeze_encoder=False,
    model_type="vit_b",
    checkpoint_path=None,
):
    """Get a simple 3D segmentation model based on the SAM image encoder.

    Args:
        device: The device for the model.
        n_classes: The number of output classes of the segmentation head.
        image_size: The size to which the input images are resized.
        lora_rank: The rank for LoRA-based fine-tuning of the image encoder. By default (None), LoRA is not used.
        freeze_encoder: Whether to freeze the image encoder. Ignored when LoRA is used.
        model_type: The SAM model type to use as the backbone.
        checkpoint_path: Optional path to a custom checkpoint.

    Returns:
        The 3D segmentation model.
    """
    if lora_rank is None:
        peft_kwargs = {}
    else:
        peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery}

    _, sam = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        image_size=image_size,
        flexible_load_checkpoint=True,
        peft_kwargs=peft_kwargs,
    )

    # Make sure not to freeze the encoder when using LoRA.
    freeze_encoder_ = freeze_encoder if lora_rank is None else False
    sam_3d = SimpleSam3DWrapper(sam, num_classes=n_classes, freeze_encoder=freeze_encoder_)
    sam_3d.to(device)
    return sam_3d
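
A minimal usage sketch for this factory. With checkpoint_path=None the pretrained vit_b weights are fetched via get_sam_model; the class count, image size, and LoRA rank below are illustrative placeholders, not library defaults.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_simple_sam_3d_model(
    device=device,
    n_classes=2,      # e.g. background vs. foreground
    image_size=512,   # input slices are resized to 512x512
    lora_rank=4,      # None trains (or freezes) the full encoder instead
)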


class BasicBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True,
        mode="nearest"
    ):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        self.downsample = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        self.leakyrelu = nn.LeakyReLU()

        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode=mode)

    def forward(self, x):
        residual = self.downsample(x)

        out = self.conv1(x)
        out = self.conv2(out)
        out += residual

        out = self.leakyrelu(out)
        out = self.up(out)
        return out
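
The decoder path of the wrapper is built from these blocks. As a quick, illustrative shape check (random input, not part of the library): each BasicBlock is a residual block whose (1, 2, 2) upsampling keeps the depth axis and doubles both spatial axes.

import torch

block = BasicBlock(in_channels=256, out_channels=128)
x = torch.randn(1, 256, 8, 64, 64)  # (batch, channels, depth, height, width)
y = block(x)
print(y.shape)  # torch.Size([1, 128, 8, 128, 128])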


class SegmentationHead(nn.Sequential):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True
    ):
        super().__init__()

        self.conv_pred = nn.Sequential(
            nn.Conv3d(
                in_channels, in_channels // 2, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            ),
            nn.InstanceNorm3d(in_channels // 2),
            nn.LeakyReLU()
        )
        self.segmentation_head = nn.Conv3d(in_channels // 2, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv_pred(x)
        return self.segmentation_head(x)
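
Analogously, a hedged shape sketch for the head: conv_pred halves the channel dimension and segmentation_head projects to one logit per class, with all spatial dimensions preserved.

import torch

head = SegmentationHead(in_channels=16, out_channels=2)
x = torch.randn(1, 16, 4, 64, 64)
print(head(x).shape)  # torch.Size([1, 2, 4, 64, 64])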


class SimpleSam3DWrapper(nn.Module):
    def __init__(self, sam, num_classes, freeze_encoder):
        super().__init__()

        self.sam = sam
        self.freeze_encoder = freeze_encoder
        if self.freeze_encoder:
            for param in self.sam.image_encoder.parameters():
                param.requires_grad = False
            self.no_grad = torch.no_grad
        else:
            self.no_grad = nullcontext

        self.decoders = nn.ModuleList([
            BasicBlock(in_channels=256, out_channels=128),
            BasicBlock(in_channels=128, out_channels=64),
            BasicBlock(in_channels=64, out_channels=32),
            BasicBlock(in_channels=32, out_channels=16),
        ])
        self.out_conv = SegmentationHead(in_channels=16, out_channels=num_classes)

    def _apply_image_encoder(self, x, D):
        encoder_features = []
        for d in range(D):
            image = x[:, :, d]
            feature = self.sam.image_encoder(image)
            encoder_features.append(feature)
        encoder_features = torch.stack(encoder_features, 2)
        return encoder_features

    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """Predict 3D masks for the current inputs.

        Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

        Args:
            batched_input: A list over input images, each a dictionary with the following keys:
                'image': The image as a torch tensor in 3xDxHxW format, already transformed for input to the model.
            multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.
                Unused in this model; kept for compatibility with the SAM interface.

        Returns:
            A list over input images, where each element is a dictionary with the following keys:
                'masks': The mask prediction for this image.
        """
        x = torch.stack([inp["image"] for inp in batched_input], dim=0)

        B, C, D, H, W = x.shape
        assert C == 3

        with self.no_grad():
            features = self._apply_image_encoder(x, D)

        out = features
        for decoder in self.decoders:
            out = decoder(out)
        logits = self.out_conv(out)

        outputs = [{"masks": mask.unsqueeze(0)} for mask in logits]
        return outputs