micro_sam.models.simple_sam_3d_wrapper
"""Wrap the SAM image encoder with a 3D convolutional decoder for volumetric segmentation."""
import os
from contextlib import nullcontext
from typing import Any, List, Dict, Union, Optional

import torch
import torch.nn as nn

from ..util import get_sam_model
from .peft_sam import LoRASurgery


def get_simple_sam_3d_model(
    device: Union[str, torch.device],
    n_classes: int,
    image_size: int,
    lora_rank: Optional[int] = None,
    freeze_encoder: bool = False,
    model_type: str = "vit_b",
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
):
    """Create a `SimpleSam3DWrapper` around a SAM model obtained from `get_sam_model`.

    Args:
        device: The device the SAM model is created on and the wrapper is moved to.
        n_classes: The number of output channels of the segmentation head.
        image_size: The image size that is forwarded to `get_sam_model`.
        lora_rank: The rank for LoRA. If given, `LoRASurgery` is requested via the
            `peft_kwargs` of `get_sam_model`; if None, no PEFT arguments are passed.
        freeze_encoder: Whether to freeze the image encoder.
            Ignored (treated as False) whenever `lora_rank` is given.
        model_type: The SAM model type that is forwarded to `get_sam_model`.
        checkpoint_path: Optional checkpoint path that is forwarded to `get_sam_model`.

    Returns:
        The `SimpleSam3DWrapper`, moved to `device`.
    """
    use_lora = lora_rank is not None
    peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery} if use_lora else {}

    _, sam = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        image_size=image_size,
        flexible_load_checkpoint=True,
        peft_kwargs=peft_kwargs,
    )

    # Never freeze the encoder when LoRA is used, regardless of `freeze_encoder`.
    sam_3d = SimpleSam3DWrapper(
        sam, num_classes=n_classes, freeze_encoder=False if use_lora else freeze_encoder
    )
    sam_3d.to(device)

    return sam_3d


class BasicBlock(nn.Module):
    """Residual 3D convolution block followed by in-plane upsampling.

    The input goes through two Conv3d + InstanceNorm3d stages (with a LeakyReLU in between)
    and is added to a 1x1x1 convolutional projection of itself. The sum is passed through a
    LeakyReLU and is then upsampled by a factor of 2 along the last two axes only.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the two main convolutions.
        stride: Stride of the first main convolution and of the residual projection.
        padding: Padding of the two main convolutions.
        bias: Whether the convolutions use a bias term.
        mode: Interpolation mode of the upsampling layer.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True,
        mode="nearest"
    ):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU()
        )

        # The second convolution always uses unit stride: the stride is already applied by
        # `conv1` and `downsample`. Applying it a second time would make the main path
        # smaller than the residual and break the addition in `forward` for non-unit strides.
        self.conv2 = nn.Sequential(
            nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        # 1x1x1 projection that brings the input to the channel count (and stride) of the main path.
        self.downsample = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        self.leakyrelu = nn.LeakyReLU()

        # Only the last two axes are upsampled, the first spatial axis keeps its size.
        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode=mode)

    def forward(self, x):
        """Apply the residual block and the upsampling to a 5D tensor `x`."""
        shortcut = self.downsample(x)

        main = self.conv2(self.conv1(x))
        main += shortcut

        return self.up(self.leakyrelu(main))


class SegmentationHead(nn.Sequential):
    """Prediction head: a Conv3d + InstanceNorm3d + LeakyReLU stage that halves the
    number of channels, followed by a 1x1x1 convolution to `out_channels`.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the final 1x1x1 convolution.
        kernel_size: Kernel size of the first convolution.
        stride: Stride of the first convolution.
        padding: Padding of the first convolution.
        bias: Whether the first convolution uses a bias term.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True
    ):
        super().__init__()

        hidden_channels = in_channels // 2
        self.conv_pred = nn.Sequential(
            nn.Conv3d(
                in_channels, hidden_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            ),
            nn.InstanceNorm3d(hidden_channels),
            nn.LeakyReLU()
        )
        self.segmentation_head = nn.Conv3d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply `conv_pred` and then `segmentation_head` to `x`."""
        return self.segmentation_head(self.conv_pred(x))


class SimpleSam3DWrapper(nn.Module):
    """Segment volumes by running the SAM image encoder per slice and decoding in 3D.

    The image encoder of `sam` is applied to every slice along the depth axis. The stacked
    features are decoded by four `BasicBlock`s (256 -> 128 -> 64 -> 32 -> 16 channels)
    and a `SegmentationHead` with `num_classes` output channels.

    Args:
        sam: The SAM model. Only its `image_encoder` is used here; the first decoder block
            expects encoder features with 256 channels.
        num_classes: Number of output channels of the segmentation head.
        freeze_encoder: If true, the parameters of the image encoder do not require
            gradients and the encoder is run under `torch.no_grad`.
    """
    def __init__(self, sam, num_classes, freeze_encoder):
        super().__init__()

        self.sam = sam
        self.freeze_encoder = freeze_encoder
        if freeze_encoder:
            for param in self.sam.image_encoder.parameters():
                param.requires_grad = False
        # Context manager factory that is entered around the encoder call in `forward`.
        self.no_grad = torch.no_grad if freeze_encoder else nullcontext

        channels = (256, 128, 64, 32, 16)
        self.decoders = nn.ModuleList([
            BasicBlock(in_channels=c_in, out_channels=c_out)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])
        self.out_conv = SegmentationHead(in_channels=channels[-1], out_channels=num_classes)

    def _apply_image_encoder(self, x, D):
        """Encode each of the `D` slices along axis 2 of `x` and stack the features along axis 2."""
        return torch.stack([self.sam.image_encoder(x[:, :, d]) for d in range(D)], 2)

    def forward(
        self, batched_input: List[Dict[str, Any]], multimask_output: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """Predict 3D masks for the current inputs.

        Unlike original SAM this model only supports automatic segmentation and does not support prompts.

        Args:
            batched_input: A list over input images, each a dictionary with the following keys.
                'image': The image as a torch tensor in 3xDxHxW format.
                    Already transformed for the input to the model.
            multimask_output: Not used by this model.

        Returns:
            A list over input images, where each element is a dictionary with the following keys:
                'masks': Mask prediction for this object.
        """
        x = torch.stack([inp["image"] for inp in batched_input], dim=0)

        _, C, D, _, _ = x.shape
        assert C == 3

        with self.no_grad():
            out = self._apply_image_encoder(x, D)

        for decoder in self.decoders:
            out = decoder(out)
        logits = self.out_conv(out)

        # One dictionary per input image; the mask gets a leading singleton axis.
        return [{"masks": mask.unsqueeze(0)} for mask in logits]
def get_simple_sam_3d_model(
    device: Union[str, torch.device],
    n_classes: int,
    image_size: int,
    lora_rank: Optional[int] = None,
    freeze_encoder: bool = False,
    model_type: str = "vit_b",
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
):
    """Create a `SimpleSam3DWrapper` around a SAM model obtained from `get_sam_model`.

    Args:
        device: The device the SAM model is created on and the wrapper is moved to.
        n_classes: The number of classes, passed to the wrapper as `num_classes`.
        image_size: The image size that is forwarded to `get_sam_model`.
        lora_rank: The rank for LoRA. If given, `LoRASurgery` is requested via the
            `peft_kwargs` of `get_sam_model`; if None, no PEFT arguments are passed.
        freeze_encoder: Whether to freeze the image encoder.
            Ignored (treated as False) whenever `lora_rank` is given.
        model_type: The SAM model type that is forwarded to `get_sam_model`.
        checkpoint_path: Optional checkpoint path that is forwarded to `get_sam_model`.

    Returns:
        The `SimpleSam3DWrapper`, moved to `device`.
    """
    use_lora = lora_rank is not None
    peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery} if use_lora else {}

    _, sam = get_sam_model(
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
        return_sam=True,
        image_size=image_size,
        flexible_load_checkpoint=True,
        peft_kwargs=peft_kwargs,
    )

    # Never freeze the encoder when LoRA is used, regardless of `freeze_encoder`.
    sam_3d = SimpleSam3DWrapper(
        sam, num_classes=n_classes, freeze_encoder=False if use_lora else freeze_encoder
    )
    sam_3d.to(device)

    return sam_3d
class BasicBlock(nn.Module):
    """Residual 3D convolution block followed by in-plane upsampling.

    The input goes through two Conv3d + InstanceNorm3d stages (with a LeakyReLU in between)
    and is added to a 1x1x1 convolutional projection of itself. The sum is passed through a
    LeakyReLU and is then upsampled by a factor of 2 along the last two axes only.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the two main convolutions.
        stride: Stride of the first main convolution and of the residual projection.
        padding: Padding of the two main convolutions.
        bias: Whether the convolutions use a bias term.
        mode: Interpolation mode of the upsampling layer.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True,
        mode="nearest"
    ):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU()
        )

        # The second convolution always uses unit stride: the stride is already applied by
        # `conv1` and `downsample`. Applying it a second time would make the main path
        # smaller than the residual and break the addition in `forward` for non-unit strides.
        self.conv2 = nn.Sequential(
            nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        # 1x1x1 projection that brings the input to the channel count (and stride) of the main path.
        self.downsample = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias),
            nn.InstanceNorm3d(out_channels)
        )

        self.leakyrelu = nn.LeakyReLU()

        # Only the last two axes are upsampled, the first spatial axis keeps its size.
        self.up = nn.Upsample(scale_factor=(1, 2, 2), mode=mode)

    def forward(self, x):
        """Apply the residual block and the upsampling to a 5D tensor `x`."""
        shortcut = self.downsample(x)

        main = self.conv2(self.conv1(x))
        main += shortcut

        return self.up(self.leakyrelu(main))
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call to()
, etc.
As per the example above, an __init__()
call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(
    self,
    in_channels,
    out_channels,
    kernel_size=(3, 3, 3),
    stride=(1, 1, 1),
    padding=(1, 1, 1),
    bias=True,
    mode="nearest"
):
    """Build the layers of the residual block.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the two main convolutions.
        stride: Stride of the first main convolution and of the residual projection.
        padding: Padding of the two main convolutions.
        bias: Whether the convolutions use a bias term.
        mode: Interpolation mode of the upsampling layer.
    """
    super().__init__()

    self.conv1 = nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        nn.InstanceNorm3d(out_channels),
        nn.LeakyReLU()
    )

    # The second convolution always uses unit stride: the stride is already applied by
    # `conv1` and `downsample`. Applying it a second time would make the main path
    # smaller than the residual and break the residual addition for non-unit strides.
    self.conv2 = nn.Sequential(
        nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=bias),
        nn.InstanceNorm3d(out_channels)
    )

    # 1x1x1 projection that brings the input to the channel count (and stride) of the main path.
    self.downsample = nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias),
        nn.InstanceNorm3d(out_channels)
    )

    self.leakyrelu = nn.LeakyReLU()

    # Only the last two axes are upsampled, the first spatial axis keeps its size.
    self.up = nn.Upsample(scale_factor=(1, 2, 2), mode=mode)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Run the residual block on `x` and upsample the result.

    The projection `self.downsample(x)` is added in-place to the output of
    `self.conv1` followed by `self.conv2`; the sum goes through `self.leakyrelu`
    and finally through `self.up`.
    """
    shortcut = self.downsample(x)

    main = self.conv2(self.conv1(x))
    main += shortcut

    return self.up(self.leakyrelu(main))
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module
instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class SegmentationHead(nn.Sequential):
    """Prediction head: a Conv3d + InstanceNorm3d + LeakyReLU stage that halves the
    number of channels, followed by a 1x1x1 convolution to `out_channels`.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the final 1x1x1 convolution.
        kernel_size: Kernel size of the first convolution.
        stride: Stride of the first convolution.
        padding: Padding of the first convolution.
        bias: Whether the first convolution uses a bias term.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(3, 3, 3),
        stride=(1, 1, 1),
        padding=(1, 1, 1),
        bias=True
    ):
        super().__init__()

        hidden_channels = in_channels // 2
        self.conv_pred = nn.Sequential(
            nn.Conv3d(
                in_channels, hidden_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            ),
            nn.InstanceNorm3d(hidden_channels),
            nn.LeakyReLU()
        )
        self.segmentation_head = nn.Conv3d(hidden_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply `conv_pred` and then `segmentation_head` to `x`."""
        return self.segmentation_head(self.conv_pred(x))
A sequential container.
Modules will be added to it in the order they are passed in the
constructor. Alternatively, an OrderedDict
of modules can be
passed in. The forward()
method of Sequential
accepts any
input and forwards it to the first module it contains. It then
"chains" outputs to inputs sequentially for each subsequent module,
finally returning the output of the last module.
The value a Sequential
provides over manually calling a sequence
of modules is that it allows treating the whole container as a
single module, such that performing a transformation on the
Sequential
applies to each of the modules it stores (which are
each a registered submodule of the Sequential
).
What's the difference between a Sequential
and a
torch.nn.ModuleList
? A ModuleList
is exactly what it
sounds like--a list for storing Module
s! On the other hand,
the layers in a Sequential
are connected in a cascading way.
Example::
# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
def __init__(
    self,
    in_channels,
    out_channels,
    kernel_size=(3, 3, 3),
    stride=(1, 1, 1),
    padding=(1, 1, 1),
    bias=True
):
    """Build the two stages of the segmentation head.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the final 1x1x1 convolution.
        kernel_size: Kernel size of the first convolution.
        stride: Stride of the first convolution.
        padding: Padding of the first convolution.
        bias: Whether the first convolution uses a bias term.
    """
    super().__init__()

    # The first stage halves the number of channels (integer division).
    hidden_channels = in_channels // 2
    self.conv_pred = nn.Sequential(
        nn.Conv3d(
            in_channels, hidden_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
        ),
        nn.InstanceNorm3d(hidden_channels),
        nn.LeakyReLU()
    )
    self.segmentation_head = nn.Conv3d(hidden_channels, out_channels, kernel_size=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module
instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.container.Sequential
- pop
- append
- insert
- extend
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class SimpleSam3DWrapper(nn.Module):
    """Segment volumes by running the SAM image encoder per slice and decoding in 3D.

    The image encoder of `sam` is applied to every slice along the depth axis. The stacked
    features are decoded by four `BasicBlock`s (256 -> 128 -> 64 -> 32 -> 16 channels)
    and a `SegmentationHead` with `num_classes` output channels.

    Args:
        sam: The SAM model. Only its `image_encoder` is used here; the first decoder block
            expects encoder features with 256 channels.
        num_classes: Number of output channels of the segmentation head.
        freeze_encoder: If true, the parameters of the image encoder do not require
            gradients and the encoder is run under `torch.no_grad`.
    """
    def __init__(self, sam, num_classes, freeze_encoder):
        super().__init__()

        self.sam = sam
        self.freeze_encoder = freeze_encoder
        if freeze_encoder:
            for param in self.sam.image_encoder.parameters():
                param.requires_grad = False
        # Context manager factory that is entered around the encoder call in `forward`.
        self.no_grad = torch.no_grad if freeze_encoder else nullcontext

        channels = (256, 128, 64, 32, 16)
        self.decoders = nn.ModuleList([
            BasicBlock(in_channels=c_in, out_channels=c_out)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])
        self.out_conv = SegmentationHead(in_channels=channels[-1], out_channels=num_classes)

    def _apply_image_encoder(self, x, D):
        """Encode each of the `D` slices along axis 2 of `x` and stack the features along axis 2."""
        return torch.stack([self.sam.image_encoder(x[:, :, d]) for d in range(D)], 2)

    def forward(
        self, batched_input: List[Dict[str, Any]], multimask_output: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """Predict 3D masks for the current inputs.

        Unlike original SAM this model only supports automatic segmentation and does not support prompts.

        Args:
            batched_input: A list over input images, each a dictionary with the following keys.
                'image': The image as a torch tensor in 3xDxHxW format.
                    Already transformed for the input to the model.
            multimask_output: Not used by this model.

        Returns:
            A list over input images, where each element is a dictionary with the following keys:
                'masks': Mask prediction for this object.
        """
        x = torch.stack([inp["image"] for inp in batched_input], dim=0)

        _, C, D, _, _ = x.shape
        assert C == 3

        with self.no_grad():
            out = self._apply_image_encoder(x, D)

        for decoder in self.decoders:
            out = decoder(out)
        logits = self.out_conv(out)

        # One dictionary per input image; the mask gets a leading singleton axis.
        return [{"masks": mask.unsqueeze(0)} for mask in logits]
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call to()
, etc.
As per the example above, an __init__()
call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, sam, num_classes, freeze_encoder):
    """Store the SAM model and build the 3D decoder.

    Args:
        sam: The SAM model. Only its `image_encoder` is accessed here; the first decoder
            block is built for 256 input channels.
        num_classes: Number of output channels of the segmentation head.
        freeze_encoder: If true, the parameters of the image encoder are set to not
            require gradients and `self.no_grad` is `torch.no_grad`; otherwise
            `self.no_grad` is `nullcontext`.
    """
    super().__init__()

    self.sam = sam
    self.freeze_encoder = freeze_encoder
    if freeze_encoder:
        for param in self.sam.image_encoder.parameters():
            param.requires_grad = False
    # Context manager factory, selected once here depending on `freeze_encoder`.
    self.no_grad = torch.no_grad if freeze_encoder else nullcontext

    channels = (256, 128, 64, 32, 16)
    self.decoders = nn.ModuleList([
        BasicBlock(in_channels=c_in, out_channels=c_out)
        for c_in, c_out in zip(channels[:-1], channels[1:])
    ])
    self.out_conv = SegmentationHead(in_channels=channels[-1], out_channels=num_classes)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(
    self, batched_input: List[Dict[str, Any]], multimask_output: bool
) -> List[Dict[str, torch.Tensor]]:
    """Predict 3D masks for the current inputs.

    Unlike original SAM this model only supports automatic segmentation and does not support prompts.

    Args:
        batched_input: A list over input images, each a dictionary with the following keys.
            'image': The image as a torch tensor in 3xDxHxW format.
                Already transformed for the input to the model.
        multimask_output: Not used by this model.

    Returns:
        A list over input images, where each element is a dictionary with the following keys:
            'masks': Mask prediction for this object.
    """
    x = torch.stack([inp["image"] for inp in batched_input], dim=0)

    _, C, D, _, _ = x.shape
    assert C == 3

    with self.no_grad():
        out = self._apply_image_encoder(x, D)

    for decoder in self.decoders:
        out = decoder(out)
    logits = self.out_conv(out)

    # One dictionary per input image; the mask gets a leading singleton axis.
    return [{"masks": mask.unsqueeze(0)} for mask in logits]
Predict 3D masks for the current inputs.
Unlike original SAM this model only supports automatic segmentation and does not support prompts.
Arguments:
- batched_input: A list over input images, each a dictionary with the following keys. 'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
- multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.
Returns:
A list over input images, where each element is a dictionary with the following keys: 'masks': Mask prediction for this object.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile