
Commit

update more model configs
yangapku committed Oct 10, 2022
1 parent cf90126 commit 7b3b1d3
Showing 12 changed files with 84 additions and 17 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -2,19 +2,20 @@

<p align="center">
<br>
<img src="assets/Chinese_CLIP_logo_tp_path.svg" width="400" />
<img src="assets/Chinese_CLIP_logo_tp.svg" width="400" />
<br>
<p>
<p align="center">
<a href="https://opensource.org/licenses/MIT">
<img alt="License: MIT" src="https://img.shields.io/badge/License-MIT-yellow.svg">
<img alt="License: MIT" src="https://img.shields.io/badge/License-MIT-yellow.svg" />
</a>
</p>

This project is the **Chinese** version of the CLIP model, trained on large-scale Chinese data (~200 million image-text pairs). It aims to help users with cross-modal retrieval, image representation, and related tasks in the Chinese domain. The code is built on the <b>[open_clip project](https://github.com/mlfoundations/open_clip)</b> and has been optimized for Chinese-domain data and for better performance on Chinese data. The project provides an API, training code, and evaluation code, described in detail below.
<br><br>

## News
* 2022.9.22 Added the ViT-L-14 and ViT-L-14-336 models
* 2022.7.13 Added the API, making it easy to quickly use the Chinese CLIP models
* 2022.7.8 The Chinese CLIP project was officially open-sourced
<br><br>
@@ -106,9 +107,7 @@ pip install -r requirements.txt
## API Quick Start
Below is a simple code example showing how to use the Chinese CLIP API. Before getting started, please install cn_clip first:
```bash
# Install the latest stable version
pip install cn_clip
# Or install from source
# Install from source
cd Chinese-CLIP/
pip install -e .
```
@@ -118,7 +117,9 @@ import torch
from PIL import Image

import cn_clip.clip as clip
from cn_clip.clip import load_from_name
from cn_clip.clip import load_from_name, available_models
print("Available models:", available_models()) # Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336']

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
model.eval()
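
The collapsed part of this README example continues with encoding and scoring. For orientation, here is a minimal sketch of how the loaded model is typically used afterwards; the image path and candidate labels are illustrative, and the similarity is computed by hand from the normalized features rather than through any repo-specific helper:

```python
import torch
from PIL import Image

import cn_clip.clip as clip
from cn_clip.clip import load_from_name

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')
model.eval()

# Illustrative inputs: one image and a few Chinese candidate captions.
image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Normalize, then treat scaled cosine similarities as logits over the captions.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", probs.cpu().numpy())
```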
@@ -390,7 +391,6 @@ cat output.json
<br><br>

## Roadmap
+ Release the ViT-L-14 scale Chinese-CLIP model (in training)
+ Provide an image-text retrieval demo based on Chinese-CLIP, along with instructions for deploying the demo in users' own environments
+ Validate results on more downstream image-text retrieval tasks
+ Release the Chinese-CLIP technical report
2 changes: 1 addition & 1 deletion README_En.md
@@ -7,7 +7,7 @@
<p>
<p align="center">
<a href="https://opensource.org/licenses/MIT">
<img alt="License: MIT" src="https://img.shields.io/badge/License-MIT-yellow.svg">
<img alt="License: MIT" src="https://img.shields.io/badge/License-MIT-yellow.svg" />
</a>
</p>

13 changes: 13 additions & 0 deletions cn_clip/clip/model_configs/RBT3-chinese.json
@@ -0,0 +1,13 @@
{
"vocab_size": 21128,
"text_attention_probs_dropout_prob": 0.1,
"text_hidden_act": "gelu",
"text_hidden_dropout_prob": 0.1,
"text_hidden_size": 768,
"text_initializer_range": 0.02,
"text_intermediate_size": 3072,
"text_max_position_embeddings": 512,
"text_num_attention_heads": 12,
"text_num_hidden_layers": 3,
"text_type_vocab_size": 2
}
7 changes: 7 additions & 0 deletions cn_clip/clip/model_configs/RN50.json
@@ -0,0 +1,7 @@
{
"embed_dim": 1024,
"image_resolution": 224,
"vision_layers": "[3,4,6,3]",
"vision_width": 64,
"vision_patch_size": null
}
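
Note that `vision_layers` is stored here as a JSON string; the loader changes later in this commit (extract_features.py, main.py) parse it back into a Python list before the model is constructed. A minimal sketch of that config-merging pattern, using `ast.literal_eval` as a safer stand-in for the `eval()` call used in the repo code (the paths are spelled out directly here for illustration; the scripts resolve them from the `--vision-model` / `--text-model` flags):

```python
import json
from ast import literal_eval  # stand-in for the eval() used in the repo code

vision_cfg_path = "cn_clip/clip/model_configs/RN50.json"
text_cfg_path = "cn_clip/clip/model_configs/RBT3-chinese.json"

with open(vision_cfg_path) as fv, open(text_cfg_path) as ft:
    model_info = json.load(fv)
    # "[3,4,6,3]" -> [3, 4, 6, 3] so the ResNet stage depths become a real list
    if isinstance(model_info["vision_layers"], str):
        model_info["vision_layers"] = literal_eval(model_info["vision_layers"])
    model_info.update(json.load(ft))  # merge in the text-tower hyperparameters

print(model_info["vision_layers"], model_info["text_num_hidden_layers"])  # [3, 4, 6, 3] 3
```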
10 changes: 10 additions & 0 deletions cn_clip/clip/utils.py
Expand Up @@ -17,11 +17,21 @@

_MODELS = {
"ViT-B-16": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-b-16.pt",
"ViT-L-14": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-l-14.pt",
"ViT-L-14-336": "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/clip_cn_vit-l-14-336.pt",
}
_MODEL_INFO = {
"ViT-B-16": {
"struct": "ViT-B-16@RoBERTa-wwm-ext-base-chinese",
"input_resolution": 224
},
"ViT-L-14": {
"struct": "ViT-L-14@RoBERTa-wwm-ext-base-chinese",
"input_resolution": 224
},
"ViT-L-14-336": {
"struct": "ViT-L-14-336@RoBERTa-wwm-ext-base-chinese",
"input_resolution": 336
}
}
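
With the two new entries registered in `_MODELS` and `_MODEL_INFO`, the added checkpoints become reachable through the public helpers. A small usage sketch (`download_root` is illustrative; the checkpoint is fetched from the URL registered above on first use):

```python
from cn_clip.clip import available_models, load_from_name

print(available_models())  # expected to now include 'ViT-L-14' and 'ViT-L-14-336'

# load_from_name looks the name up in _MODELS to download the checkpoint and uses
# _MODEL_INFO's input_resolution to build preprocessing (336 px for ViT-L-14-336).
model, preprocess = load_from_name("ViT-L-14-336", device="cpu", download_root="./")
```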

6 changes: 4 additions & 2 deletions cn_clip/eval/extract_features.py
@@ -77,13 +77,13 @@ def parse_args():
)
parser.add_argument(
"--vision-model",
choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336"],
choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "RN50"],
default="ViT-B-16",
help="Name of the vision backbone to use.",
)
parser.add_argument(
"--text-model",
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese"],
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
default="RoBERTa-wwm-ext-base-chinese",
help="Name of the text backbone to use.",
)
@@ -123,6 +123,8 @@ def parse_args():

with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
model_info = json.load(fv)
if isinstance(model_info['vision_layers'], str):
model_info['vision_layers'] = eval(model_info['vision_layers'])
for k, v in json.load(ft).items():
model_info[k] = v

8 changes: 4 additions & 4 deletions cn_clip/eval/zeroshot_evaluation.py
@@ -23,16 +23,16 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--vision-model",
choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336"],
choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "RN50"],
default="ViT-B-16",
help="Name of the vision backbone to use.",
)
)
parser.add_argument(
"--text-model",
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese"],
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
default="RoBERTa-wwm-ext-base-chinese",
help="Name of the text backbone to use.",
)
)
parser.add_argument(
"--precision",
choices=["amp", "fp16", "fp32"],
7 changes: 7 additions & 0 deletions cn_clip/training/main.py
@@ -75,6 +75,8 @@ def main():

with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
model_info = json.load(fv)
if isinstance(model_info['vision_layers'], str):
model_info['vision_layers'] = eval(model_info['vision_layers'])
for k, v in json.load(ft).items():
model_info[k] = v

@@ -103,6 +105,11 @@ def main():
if args.freeze_vision:
for k, v in model.visual.named_parameters():
v.requires_grad = False
# freeze bn running mean and variance
if args.vision_model in ['RN50']:
for m in model.visual.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
logging.info("The visual encoder is frozen during training.")

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_device_rank], find_unused_parameters=False)
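
For context on why the extra BatchNorm handling is needed: setting `requires_grad = False` on the visual parameters only stops gradient updates to the affine weights; in training mode `BatchNorm2d` still refreshes its running mean and variance from every batch, so an RN50 tower would drift even while "frozen". A self-contained illustration:

```python
import torch

bn = torch.nn.BatchNorm2d(4)
for p in bn.parameters():
    p.requires_grad = False  # freezes weight/bias, but not the running statistics

before = bn.running_mean.clone()
bn(torch.randn(8, 4, 16, 16))                   # train mode: stats still update
print(torch.allclose(before, bn.running_mean))  # False

bn.eval()
before = bn.running_mean.clone()
bn(torch.randn(8, 4, 16, 16))                   # eval mode: stats untouched
print(torch.allclose(before, bn.running_mean))  # True
```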
4 changes: 2 additions & 2 deletions cn_clip/training/params.py
@@ -125,7 +125,7 @@ def parse_args():
)
parser.add_argument(
"--vision-model",
choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336"],
choices=["ViT-B-32", "ViT-B-16", "ViT-L-14", "ViT-L-14-336", "RN50"],
default="ViT-B-16",
help="Name of the vision backbone to use.",
)
@@ -143,7 +143,7 @@
)
parser.add_argument(
"--text-model",
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese"],
choices=["RoBERTa-wwm-ext-base-chinese", "RoBERTa-wwm-ext-large-chinese", "RBT3-chinese"],
default="RoBERTa-wwm-ext-base-chinese",
help="Name of the text backbone to use.",
)
6 changes: 6 additions & 0 deletions cn_clip/training/train.py
@@ -72,6 +72,12 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
# os.environ["WDS_EPOCH"] = str(epoch)

model.train()
# freeze bn running mean and variance
if args.freeze_vision and args.vision_model in ['RN50']:
RN_visual_modules = model.module.visual.modules() if isinstance(model, nn.parallel.DistributedDataParallel) else model.visual.modules()
for m in RN_visual_modules:
if isinstance(m, nn.BatchNorm2d):
m.eval()

dataloader, sampler = data['train'].dataloader, data['train'].sampler
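
Worth noting why this repeats the freeze already applied in main.py: `model.train()` at the start of each epoch switches every submodule, including `BatchNorm2d`, back to training mode, so the `.eval()` override has to be re-applied here, and under `DistributedDataParallel` the underlying network is reached via `.module`. A tiny illustration of the first point:

```python
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
net[1].eval()           # freeze the BN statistics
print(net[1].training)  # False
net.train()             # start of a new epoch
print(net[1].training)  # True -- the freeze is gone and must be re-applied
```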

2 changes: 1 addition & 1 deletion run_scripts/flickr30k_finetune_vit-b-16_rbt-base.sh
@@ -42,7 +42,7 @@ report_training_batch_acc="--report-training-batch-acc"
# report_training_batch_acc=""

# training hyper-params
context_length=24
context_length=52
warmup=100
batch_size=128
valid_batch_size=128
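
The `context_length` bump from 24 to 52 leaves more room for longer Flickr30K-CN captions before truncation. A hedged sketch of what the parameter controls, assuming `cn_clip`'s `tokenize` mirrors the upstream CLIP API and accepts a `context_length` argument (the caption is illustrative):

```python
import cn_clip.clip as clip

caption = "一只戴着红色围巾的小狗坐在公园的长椅上"  # illustrative Chinese caption
tokens = clip.tokenize([caption], context_length=52)
print(tokens.shape)  # expected: torch.Size([1, 52]); shorter inputs padded, longer truncated
```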
22 changes: 22 additions & 0 deletions run_scripts/imagenet_zeroshot_eval.sh
@@ -0,0 +1,22 @@
#!/usr/bin/env bash

# NOTE: We have not officially released our zero-shot classification pipeline yet.
# This script and its corresponding python entryfile may be changed in further releases.

# usage: bash run_scripts/imagenet_zeroshot_eval.sh CKPT_PATH

# only supports single-GPU inference
export CUDA_VISIBLE_DEVICES=0
export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip

DATAPATH=../imagenet-1k_val # provide the path of imagenet-val directory in torchvision dataset format
resume=${1}
vision_model=${2:-ViT-B-16}

python -u cn_clip/eval/zeroshot_evaluation.py \
--imagenet-val="${DATAPATH}" \
--img-batch-size=64 \
--context-length=32 \
--resume=${resume} \
--vision-model=${vision_model} \
--text-model=RoBERTa-wwm-ext-base-chinese
