micro_sam.models.simple_sam_3d_wrapper
```python
from contextlib import nullcontext
from typing import Any, List, Dict

import torch
import torch.nn as nn

from ..util import get_sam_model
from .peft_sam import LoRASurgery
```
```python
def get_simple_sam_3d_model(
    device,
    n_classes,
    image_size,
    lora_rank=None,
    freeze_encoder=False,
    model_type="vit_b",
    checkpoint_path=None,
):
    if lora_rank is None:
        peft_kwargs = {}
    else:
        peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery}

    _, sam = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        image_size=image_size,
        flexible_load_checkpoint=True,
        peft_kwargs=peft_kwargs,
    )

    # Make sure not to freeze the encoder when using LoRA.
    freeze_encoder_ = freeze_encoder if lora_rank is None else False
    sam_3d = SimpleSam3DWrapper(sam, num_classes=n_classes, freeze_encoder=freeze_encoder_)
    sam_3d.to(device)
    return sam_3d
```
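For example, the wrapper can be constructed like this (a minimal sketch; the device, class count, image size, and LoRA rank below are illustrative placeholder values, not library defaults):

```python
import torch

from micro_sam.models.simple_sam_3d_wrapper import get_simple_sam_3d_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build a 3D wrapper around a ViT-B SAM encoder with LoRA adapters of rank 4.
# Note: when lora_rank is set, freeze_encoder is ignored and the encoder stays trainable.
model = get_simple_sam_3d_model(
    device=device,
    n_classes=2,      # e.g. foreground / background
    image_size=512,   # spatial size the SAM image encoder is configured for
    lora_rank=4,      # None would load the model without PEFT
)
```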
```python
class BasicBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True,
        mode="nearest"
    ):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        self.downsample = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        self.leakyrelu = nn.LeakyReLU()

        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode=mode)

    def forward(self, x):
        residual = self.downsample(x)

        out = self.conv1(x)
        out = self.conv2(out)
        out += residual

        out = self.leakyrelu(out)
        out = self.up(out)
        return out
```
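Each `BasicBlock` is a residual 3D convolution block followed by a 2x in-plane upsampling; the depth axis is left unchanged (`scale_factor=(1, 2, 2)`). A quick shape check (tensor sizes are illustrative):

```python
import torch

from micro_sam.models.simple_sam_3d_wrapper import BasicBlock

block = BasicBlock(in_channels=256, out_channels=128)
x = torch.randn(1, 256, 8, 64, 64)  # (B, C, D, H, W)
y = block(x)
print(y.shape)  # torch.Size([1, 128, 8, 128, 128]): channels halved, H and W doubled, D unchanged
```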
```python
class SegmentationHead(nn.Sequential):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True
    ):
        super().__init__()

        self.conv_pred = nn.Sequential(
            nn.Conv3d(
                in_channels, in_channels // 2, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            ),
            nn.InstanceNorm3d(in_channels // 2),
            nn.LeakyReLU()
        )
        self.segmentation_head = nn.Conv3d(in_channels // 2, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv_pred(x)
        return self.segmentation_head(x)
```
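The head halves the channel dimension with a 3x3x3 convolution block and then projects to the number of output classes with a 1x1x1 convolution; the spatial dimensions are preserved. For instance (illustrative sizes):

```python
import torch

from micro_sam.models.simple_sam_3d_wrapper import SegmentationHead

head = SegmentationHead(in_channels=16, out_channels=2)
x = torch.randn(1, 16, 8, 256, 256)  # (B, C, D, H, W) from the last decoder block
logits = head(x)
print(logits.shape)  # torch.Size([1, 2, 8, 256, 256]): one logit channel per class
```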
```python
class SimpleSam3DWrapper(nn.Module):
    def __init__(self, sam, num_classes, freeze_encoder):
        super().__init__()

        self.sam = sam
        self.freeze_encoder = freeze_encoder
        if self.freeze_encoder:
            for param in self.sam.image_encoder.parameters():
                param.requires_grad = False
            self.no_grad = torch.no_grad
        else:
            self.no_grad = nullcontext

        self.decoders = nn.ModuleList([
            BasicBlock(in_channels=256, out_channels=128),
            BasicBlock(in_channels=128, out_channels=64),
            BasicBlock(in_channels=64, out_channels=32),
            BasicBlock(in_channels=32, out_channels=16),
        ])
        self.out_conv = SegmentationHead(in_channels=16, out_channels=num_classes)

    def _apply_image_encoder(self, x, D):
        # Encode each depth slice independently with the SAM image encoder,
        # then stack the per-slice features along a new depth axis.
        encoder_features = []
        for d in range(D):
            image = x[:, :, d]
            feature = self.sam.image_encoder(image)
            encoder_features.append(feature)
        encoder_features = torch.stack(encoder_features, 2)
        return encoder_features

    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """Predict 3D masks for the current inputs.

        Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

        Args:
            batched_input: A list over input images, each a dictionary with the following keys:
                'image': The image as a torch tensor in 3xDxHxW format, already transformed for input to the model.
            multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

        Returns:
            A list over input images, where each element is a dictionary with the following keys:
                'masks': The mask predictions for this image.
        """
        x = torch.stack([inp["image"] for inp in batched_input], dim=0)

        B, C, D, H, W = x.shape
        assert C == 3

        with self.no_grad():
            features = self._apply_image_encoder(x, D)

        out = features
        for decoder in self.decoders:
            out = decoder(out)
        logits = self.out_conv(out)

        outputs = [{"masks": mask.unsqueeze(0)} for mask in logits]
        return outputs
```
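When `freeze_encoder` is set (and no LoRA rank is given), the SAM image encoder's parameters have `requires_grad` disabled and the encoder is additionally run under `torch.no_grad`, so only the 3D decoder stack and segmentation head are trained. A quick way to inspect this (a sketch; the constructor arguments are illustrative):

```python
from micro_sam.models.simple_sam_3d_wrapper import get_simple_sam_3d_model

model = get_simple_sam_3d_model(
    device="cpu", n_classes=2, image_size=256, freeze_encoder=True,
)

# The encoder parameters are frozen; the decoders and head remain trainable.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable: {n_trainable:,} | frozen: {n_frozen:,}")
```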
```python
def forward(
    self,
    batched_input: List[Dict[str, Any]],
    multimask_output: bool
) -> List[Dict[str, torch.Tensor]]:
    ...
```
Predict 3D masks for the current inputs.
Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.
Arguments:
- batched_input: A list over input images, each a dictionary with the following keys: 'image': The image as a torch tensor in 3xDxHxW format, already transformed for input to the model.
- multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.
Returns:
A list over input images, where each element is a dictionary with the following keys: 'masks': The mask predictions for this image.
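A minimal inference sketch (the volume size and construction arguments are illustrative; note that `multimask_output` is accepted for API compatibility with SAM but is not used in this wrapper's forward pass):

```python
import torch

from micro_sam.models.simple_sam_3d_wrapper import get_simple_sam_3d_model

model = get_simple_sam_3d_model(device="cpu", n_classes=2, image_size=256)
model.eval()

# One 3D volume with 4 depth slices; each slice is a 3-channel 256x256 image.
volume = torch.randn(3, 4, 256, 256)  # (C, D, H, W)
batched_input = [{"image": volume}]

with torch.no_grad():
    outputs = model(batched_input, multimask_output=False)

# One output dictionary per input image, with per-class logits under "masks".
print(outputs[0]["masks"].shape)
```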