Skip to content

Commit

Permalink
[Feature] Add MobileSAM (PaddlePaddle#3349)
Browse files: browse the repository at this point in the history
* [Feature] Add MobileSAM

* update

* update

* rm weight

* update tiny_vit_sam.py

* fix

* update

* Update build_sam.py

---------

Co-authored-by: shiyutang <[email protected]>
  • Loading branch information
Asthestarsfalll and shiyutang authored Jul 31, 2023
1 parent 700366a commit e61765f
Show file tree
Hide file tree
Showing 8 changed files with 677 additions and 12 deletions.
6 changes: 4 additions & 2 deletions contrib/SegmentAnything/scripts/amg_paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
'vit_l':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
'vit_b':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams"
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
'vit_t':
"https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam"
}

parser = argparse.ArgumentParser(description=(
Expand All @@ -49,7 +51,7 @@
type=str,
default="vit_l",
required=True,
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']", )
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']", )

parser.add_argument(
"--convert-to-rle",
Expand Down
11 changes: 7 additions & 4 deletions contrib/SegmentAnything/scripts/promt_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
'vit_l':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
'vit_b':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams"
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
'vit_t':
"https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam"
}


Expand All @@ -48,19 +50,20 @@ def get_args():
type=str,
default="vit_l",
required=True,
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']", )
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']",
)
parser.add_argument(
'--point_prompt',
type=int,
nargs='+',
default=None,
help='point promt.')
help='point prompt.')
parser.add_argument(
'--box_prompt',
type=int,
nargs='+',
default=None,
help='box promt format as xyxy.')
help='box prompt format as xyxy.')
parser.add_argument(
'--output_path',
type=str,
Expand Down
4 changes: 3 additions & 1 deletion contrib/SegmentAnything/scripts/text_to_sam_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_l/model.pdparams",
'vit_b':
"https://bj.bcebos.com/paddleseg/dygraph/paddlesegAnything/vit_b/model.pdparams",
'vit_t':
"https://paddleseg.bj.bcebos.com/dygraph/paddlesegAnything/vit_t/model.pdparam",
'clip_b_32':
"https://bj.bcebos.com/paddleseg/dygraph/clip/vit_b_32_pretrain/clip_vit_b_32.pdparams"
}
Expand All @@ -53,7 +55,7 @@
type=str,
default="vit_h",
required=True,
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b']", )
help="The type of model to load, in ['vit_h', 'vit_l', 'vit_b', 'vit_t']", )


def download(img):
Expand Down
50 changes: 48 additions & 2 deletions contrib/SegmentAnything/segment_anything/build_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from paddleseg.utils import load_entire_model

from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT


def build_sam_vit_h(checkpoint=None):
Expand Down Expand Up @@ -52,14 +52,60 @@ def build_sam_vit_b(checkpoint=None):
checkpoint=checkpoint, )


def build_sam_vit_t(checkpoint=None):
    """Assemble the MobileSAM variant of SAM, which uses a TinyViT image encoder.

    Args:
        checkpoint: Optional checkpoint location passed to
            ``load_entire_model``. When ``None``, no weights are loaded.

    Returns:
        The constructed ``Sam`` model, already switched to eval mode.
    """
    image_size = 1024
    vit_patch_size = 16
    prompt_embed_dim = 256
    # Side length of the embedding grid handed to the prompt encoder.
    image_embedding_size = image_size // vit_patch_size

    # Sub-modules are created in the same order as in a single nested call:
    # image encoder, prompt encoder, transformer, then mask decoder.
    image_encoder = TinyViT(
        img_size=image_size,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8, )
    prompt_encoder = PromptEncoder(
        embed_dim=prompt_embed_dim,
        image_embedding_size=(image_embedding_size, image_embedding_size),
        input_image_size=(image_size, image_size),
        mask_in_chans=16, )
    two_way_transformer = TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8, )
    mask_decoder = MaskDecoder(
        num_multimask_outputs=3,
        transformer=two_way_transformer,
        transformer_dim=prompt_embed_dim,
        iou_head_depth=3,
        iou_head_hidden_dim=256, )

    mobile_sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=prompt_encoder,
        mask_decoder=mask_decoder,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375], )
    mobile_sam.eval()

    if checkpoint is not None:
        load_entire_model(mobile_sam, checkpoint)
        # NOTE(review): the scraped diff lost its indentation; build_abs() is
        # assumed to belong to the checkpoint branch (run after the weights
        # are loaded) -- confirm against the upstream file. What build_abs()
        # computes is defined in tiny_vit_sam.py, which is not visible here.
        mobile_sam.image_encoder.build_abs()
    return mobile_sam

# Registry of SAM builder functions keyed by model-type name.
# "default" and "vit_h" resolve to the same builder (build_sam);
# "vit_t" selects the MobileSAM builder (build_sam_vit_t).
sam_model_registry = {
"default": build_sam,
"vit_h": build_sam,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
"vit_t": build_sam_vit_t,
}


def _build_sam(
encoder_embed_dim,
encoder_depth,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer
from .tiny_vit_sam import TinyViT
5 changes: 3 additions & 2 deletions contrib/SegmentAnything/segment_anything/modeling/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
from paddle import nn
from paddle.nn import functional as F

from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .tiny_vit_sam import TinyViT


class Sam(nn.Layer):
Expand All @@ -31,7 +32,7 @@ class Sam(nn.Layer):

def __init__(
self,
image_encoder: ImageEncoderViT,
image_encoder: Union[ImageEncoderViT, TinyViT],
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float]=[123.675, 116.28, 103.53],
Expand Down
Loading

0 comments on commit e61765f

Please sign in to comment.