micro_sam.models.build_sam

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# https://github.com/facebookresearch/segment-anything/

#
# NOTE: This code has been adapted from Segment Anything:
# - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/build_sam.py
# This is done in order to expose some of the model's hard-coded input parameters for:
# - downstream applications (e.g. updating the "num_multimask_outputs" for multi-class semantic segmentation)
#

import torch

from functools import partial

from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer


def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    num_multimask_outputs=3,
    image_size=1024,
):
    prompt_embed_dim = 256
    vit_patch_size = 16
    # The image encoder produces one embedding per 16x16 patch.
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=num_multimask_outputs,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        # ImageNet mean and std used to normalize the input images.
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    sam.eval()
    # Load pretrained weights if a checkpoint path is given.
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)

    return sam
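To illustrate why these parameters are exposed, here is a minimal sketch (assuming the module path shown in the title) that builds a ViT-B SAM whose mask decoder predicts four masks per prompt, e.g. as a starting point for four-class semantic segmentation. No checkpoint is passed, since the published SAM weights were trained with three multimask outputs and would not match a resized decoder head.

from micro_sam.models.build_sam import build_sam_vit_b

# Build a ViT-B SAM with a custom number of mask decoder outputs.
# checkpoint is left as None: pretrained SAM decoder weights assume
# num_multimask_outputs=3 and would not load into this head.
sam = build_sam_vit_b(checkpoint=None, num_multimask_outputs=4, image_size=1024)

print(sam.mask_decoder.num_multimask_outputs)  # 4
print(sam.image_encoder.img_size)              # 1024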
def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024)

def build_sam(checkpoint=None, num_multimask_outputs=3, image_size=1024)
    Alias of build_sam_vit_h.

def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024)

def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024)

sam_model_registry = {"default": build_sam_vit_h, "vit_h": build_sam_vit_h, "vit_l": build_sam_vit_l, "vit_b": build_sam_vit_b}
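For completeness, a minimal usage sketch of the registry; the checkpoint filename below is a placeholder for a locally downloaded SAM checkpoint:

import torch

from micro_sam.models.build_sam import sam_model_registry

model_type = "vit_b"
checkpoint_path = "sam_vit_b_01ec64.pth"  # placeholder path to a local SAM checkpoint

# Look up the builder for the requested architecture and load its weights.
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to("cuda" if torch.cuda.is_available() else "cpu")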