micro_sam.models.build_sam
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# https://github.com/facebookresearch/segment-anything/

#
# NOTE: This code has been adapted from Segment Anything.
# - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/build_sam.py
# This is done in favor of exposing some of the model's hard-coded input parameters for:
# - downstream applications (eg. updating the "num_multimask_outputs" for multi-class semantic segmentation)
#


import torch

from functools import partial

from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer


def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    """Build a SAM model with the ViT-H (huge) image encoder.

    Args:
        checkpoint: Optional path to a checkpoint whose state dict is loaded into the model.
        num_multimask_outputs: Number of masks the decoder predicts per prompt.
        image_size: Side length of the (square) input image expected by the encoder.
    """
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


# The default SAM builder is the ViT-H variant.
build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    """Build a SAM model with the ViT-L (large) image encoder.

    Args:
        checkpoint: Optional path to a checkpoint whose state dict is loaded into the model.
        num_multimask_outputs: Number of masks the decoder predicts per prompt.
        image_size: Side length of the (square) input image expected by the encoder.
    """
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    """Build a SAM model with the ViT-B (base) image encoder.

    Args:
        checkpoint: Optional path to a checkpoint whose state dict is loaded into the model.
        num_multimask_outputs: Number of masks the decoder predicts per prompt.
        image_size: Side length of the (square) input image expected by the encoder.
    """
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


# Maps a model-type name to its builder; "default" resolves to ViT-H.
sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    num_multimask_outputs=3,
    image_size=1024,
):
    """Assemble a SAM model from its image encoder, prompt encoder and mask decoder.

    Args:
        encoder_embed_dim: Embedding dimension of the ViT image encoder.
        encoder_depth: Number of transformer blocks in the image encoder.
        encoder_num_heads: Number of attention heads per encoder block.
        encoder_global_attn_indexes: Encoder block indices that use global (not windowed) attention.
        checkpoint: Optional path to a checkpoint file; its state dict is loaded into the model.
        num_multimask_outputs: Number of masks the decoder predicts per prompt.
        image_size: Side length of the (square) input image.

    Returns:
        The assembled ``Sam`` model in eval mode, with weights loaded if a checkpoint was given.
    """
    prompt_embed_dim = 256
    vit_patch_size = 16
    # One embedding position per non-overlapping ViT patch along each side.
    image_embedding_size = image_size // vit_patch_size

    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )
    prompt_encoder = PromptEncoder(
        embed_dim=prompt_embed_dim,
        image_embedding_size=(image_embedding_size, image_embedding_size),
        input_image_size=(image_size, image_size),
        mask_in_chans=16,
    )
    mask_decoder = MaskDecoder(
        num_multimask_outputs=num_multimask_outputs,
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=prompt_embed_dim,
            mlp_dim=2048,
            num_heads=8,
        ),
        transformer_dim=prompt_embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=prompt_encoder,
        mask_decoder=mask_decoder,
        # ImageNet normalization statistics, as in the upstream SAM release.
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()

    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )
def build_sam(checkpoint=None, num_multimask_outputs=3, image_size=1024):
# build_sam is an alias for build_sam_vit_h:
def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )
def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )
def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )
sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}