micro_sam.models.build_sam
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# https://github.com/facebookresearch/segment-anything/

#
# NOTE: This code has been adapted from Segment Anything.
# - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/build_sam.py
# This is done in favor of exposing some of the model's hard-coded input parameters for:
# - downstream applications (eg. updating the "num_multimask_outputs" for multi-class semantic segmentation)
#

import torch

from functools import partial

from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer


def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    num_multimask_outputs=3,
    image_size=1024,
):
    prompt_embed_dim = 256
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=num_multimask_outputs,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)

    return sam
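As the header notes, the builders expose parameters that upstream Segment Anything hard-codes, in particular num_multimask_outputs and image_size. A minimal usage sketch (the checkpoint path is a placeholder for locally downloaded SAM weights):

import micro_sam.models.build_sam as build_sam_module

# Standard SAM with pretrained weights. load_state_dict is called without
# strict=False, so official checkpoints only fit the default settings
# (num_multimask_outputs=3, image_size=1024).
sam = build_sam_module.build_sam_vit_b(checkpoint="/path/to/sam_vit_b.pth")

# Exposed parameter: a decoder with 4 output masks, e.g. for 4-class semantic
# segmentation. No checkpoint is passed here, since the decoder shapes would
# no longer match the pretrained weights.
sam_semantic = build_sam_module.build_sam_vit_b(checkpoint=None, num_multimask_outputs=4)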
def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )
build_sam = build_sam_vit_h

Alias for build_sam_vit_h, so calling build_sam constructs the default ViT-H model.
def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )
def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )
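The three builders differ only in the image encoder hyperparameters they fix; everything else is shared via _build_sam. A quick sketch, not part of the module, to compare the resulting model sizes (instantiating vit_h on CPU needs a few GB of RAM):

from micro_sam.models.build_sam import build_sam_vit_b, build_sam_vit_l, build_sam_vit_h

for build_fn in (build_sam_vit_b, build_sam_vit_l, build_sam_vit_h):
    model = build_fn()  # no checkpoint: randomly initialized weights
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{build_fn.__name__}: {n_params / 1e6:.1f}M parameters")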
sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}
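The registry maps model-type strings to their builder functions, so the variant can be chosen at runtime. A minimal sketch, where the model type is a hypothetical value that would typically come from user configuration:

import torch

from micro_sam.models.build_sam import sam_model_registry

model_type = "vit_b"  # e.g. read from a config file or CLI flag
sam = sam_model_registry[model_type](checkpoint=None)  # returned in eval() mode
sam.to("cuda" if torch.cuda.is_available() else "cpu")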