micro_sam.models.build_sam
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# https://github.com/facebookresearch/segment-anything/

#
# NOTE: This code has been adapted from Segment Anything:
# - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/build_sam.py
# The adaptation exposes some of the model's hard-coded input parameters to downstream
# applications (e.g. updating "num_multimask_outputs" for multi-class semantic segmentation).
#

from typing import OrderedDict

import torch

from functools import partial

from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer


def _validate_model_type(state: OrderedDict) -> str:
    # We compare the 'embed_dim' values in the patch embedding stage.
    if "image_encoder.patch_embed.proj.weight" in state:  # for all original SAM models.
        embed_dim_size = state["image_encoder.patch_embed.proj.weight"].shape[0]

        # Mapping for SAM models based on 'embed_dim'.
        # NOTE: We can make this more flexible by also taking the 'depth' into account.
        embed_dim_combinations = {768: "vit_b", 1024: "vit_l", 1280: "vit_h"}
        _provided_model_type = embed_dim_combinations[embed_dim_size]

    else:  # for MobileSAM (vit-tiny) models.
        _provided_model_type = "vit_t"

    return _provided_model_type


def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    num_multimask_outputs=3,
    image_size=1024,
):
    prompt_embed_dim = 256
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size  # spatial size of the image embedding (64 for the default 1024 input)
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=num_multimask_outputs,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    sam.eval()
    # Load pretrained weights if a checkpoint path was provided.
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)

    return sam
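The main reason this module re-exposes the builders is the `num_multimask_outputs` and `image_size` arguments mentioned in the header note. A minimal usage sketch (the checkpoint filename is only a placeholder for wherever the SAM weights are stored locally):

from micro_sam.models.build_sam import sam_model_registry

# Placeholder path; point this at a locally downloaded SAM ViT-B checkpoint.
checkpoint_path = "sam_vit_b.pth"

# Standard SAM ViT-B with the original weights and the default 3 multimask outputs.
sam = sam_model_registry["vit_b"](checkpoint=checkpoint_path)

# Same backbone, but with more decoder outputs, e.g. for multi-class semantic
# segmentation. The original checkpoint no longer matches the mask decoder's
# shapes in this configuration, so no checkpoint is passed here.
sam_semantic = sam_model_registry["vit_b"](num_multimask_outputs=5)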
def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
def build_sam(checkpoint=None, num_multimask_outputs=3, image_size=1024):  # alias of build_sam_vit_h
def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
sam_model_registry = {'default': <function build_sam_vit_h>, 'vit_h': <function build_sam_vit_h>, 'vit_l': <function build_sam_vit_l>, 'vit_b': <function build_sam_vit_b>}
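Since `_validate_model_type` infers the architecture from the patch-embedding shape, it can be paired with the registry to pick a builder for a checkpoint of unknown size. A sketch under the assumption that the file is a plain SAM state dict (as in the original release; the path is a placeholder):

import torch
from micro_sam.models.build_sam import _validate_model_type, sam_model_registry

checkpoint_path = "some_sam_checkpoint.pth"  # placeholder path
state = torch.load(checkpoint_path, map_location="cpu")

model_type = _validate_model_type(state)  # "vit_b", "vit_l", "vit_h" or "vit_t"

# "vit_t" (MobileSAM) has no builder in this registry, so guard the lookup.
if model_type in sam_model_registry:
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)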