micro_sam.models.build_sam

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# https://github.com/facebookresearch/segment-anything/

#
# NOTE: This code has been adapted from Segment Anything.
# - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/build_sam.py
# This is done in order to expose some of the model's hard-coded input parameters for:
# - downstream applications (e.g. updating the "num_multimask_outputs" for multi-class semantic segmentation)
#
from typing import OrderedDict
from functools import partial

import torch

from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer


def _validate_model_type(state: OrderedDict) -> str:
    # We compare the 'embed_dim' values in the patch embedding stage.
    if "image_encoder.patch_embed.proj.weight" in state:  # for all original SAM models.
        embed_dim_size = state["image_encoder.patch_embed.proj.weight"].shape[0]

        # Mapping for SAM models based on 'embed_dim'.
        # NOTE: This could be made more robust by also checking the encoder 'depth'.
        embed_dim_combinations = {768: "vit_b", 1024: "vit_l", 1280: "vit_h"}
        _provided_model_type = embed_dim_combinations[embed_dim_size]

    else:  # for MobileSAM (vit-tiny) models.
        _provided_model_type = "vit_t"

    return _provided_model_type

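# A minimal usage sketch for '_validate_model_type', assuming a SAM checkpoint
# with the original key layout has been downloaded locally (the filename below
# is the official ViT-b checkpoint):
#
#     state = torch.load("sam_vit_b_01ec64.pth", map_location="cpu")
#     print(_validate_model_type(state))  # -> "vit_b"
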
def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


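# A minimal sketch of the intended registry usage: select a builder by model
# type (or "default", which maps to vit_h) and optionally pass a checkpoint
# path (assuming the official ViT-b checkpoint is available locally):
#
#     sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
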
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    num_multimask_outputs=3,
    image_size=1024,
):
    prompt_embed_dim = 256
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=num_multimask_outputs,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)

    return sam
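
# A minimal sketch of the parameter this adaptation exposes: building a model
# with a custom number of mask outputs, e.g. for multi-class semantic
# segmentation. Note that the original SAM checkpoints were trained with
# 'num_multimask_outputs=3', so a decoder with a different number of mask
# tokens cannot be loaded strictly from them and is meant to be (re)trained:
#
#     sam = sam_model_registry["vit_b"](num_multimask_outputs=4)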