micro_sam.models.build_sam

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# https://github.com/facebookresearch/segment-anything/

#
# NOTE: This code has been adapted from Segment Anything.
# - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/build_sam.py
# This is done to expose some of the model's hard-coded input parameters for:
# - downstream applications (e.g. updating the "num_multimask_outputs" for multi-class semantic segmentation)
#


import torch

from functools import partial

from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer


def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
        num_multimask_outputs=num_multimask_outputs,
        image_size=image_size,
    )


sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
    num_multimask_outputs=3,
    image_size=1024,
):
    prompt_embed_dim = 256
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=num_multimask_outputs,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
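
For reference, a minimal usage sketch (not part of the module source): building models through sam_model_registry. The checkpoint path below is a placeholder for a locally downloaded SAM checkpoint. Note that the published SAM checkpoints were trained with num_multimask_outputs=3, so choosing a different value makes the stock mask-decoder weights incompatible and that head has to be (re)trained.

from micro_sam.models.build_sam import sam_model_registry

# Default configuration: ViT-B encoder, 3 multimask outputs, 1024 x 1024 input.
# "path/to/sam_vit_b.pth" is a placeholder for a local SAM checkpoint file.
sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")

# Exposed parameter in action: a larger decoder head, e.g. for multi-class
# semantic segmentation. Built without a checkpoint here, since the pretrained
# decoder weights expect the default of 3 outputs.
sam_semantic = sam_model_registry["vit_b"](num_multimask_outputs=4)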