forked from PaddlePaddle/PaddleMIX
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add PP-DocBee v1.0 (PaddlePaddle#862)
Benchamrk | MiniCPM-V 2.0 | SmolVLM | Aquila-VL-2B | Mini-Monkey-2B | InternVL2-2B | InternVL2.5-2B | Qwen2-VL-2B | PP-DocBee -- | -- | -- | -- | -- | -- | -- | -- | -- Model Size | 2.43B | 2.25B | 2.18B | 2.21B | 2.21B | 2.21B | 2.21B | 2.21B DocVQA-val | 71.9(test) | 81.6(test) | 85.0 | 87.4(test) | 86.9(test) | 88.7(test) | 89.2 | 90.1 ChartQA-test | - | - | 76.5 | 76.5 | 76.2 | 79.2 | 73.5 | 74.6 InfoVQA-val | - | - | 58.3 | 60.1(test) | 58.9(test) | 60.9(test) | 64.1 | 65.4 TextVQA-val | 74.1 | 72.7 | 76.4 | 76.0 | 73.4 | 74.3 | 79.7 | 81.2 OCRBench | 605 | - | 77.2 | 79.4 | 781 | 80.4 | 79.4 | 82.8 ChineseOCRBench | - | - | | | - | - | 76.1 | 80.2 内部中文评估集 | - | - | | | 44.1 | - | 52.8 | 60.3
- Loading branch information
1 parent
0af671a
commit 7d556a6
Showing
2 changed files
with
225 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# PP-DocBee | ||
|
||
|
||
# PP-DocBee | ||
|
||
## 1. 简介 | ||
|
||
PP-DocBee 是一款专注于文档理解的多模态大模型,在中文文档理解任务上具有卓越表现。该模型基于 'Qwen/Qwen2-VL-2BInstruct' 架构,通过近 500 万条文档理解类多模态数据和精选的纯文本数据进行微调优化。 | ||
|
||
## 2. 环境要求 | ||
- **python >= 3.10** | ||
- **paddlepaddle-gpu 要求版本develop** | ||
``` | ||
# 安装示例 | ||
python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html | ||
``` | ||
|
||
- **paddlenlp == 3.0.0b2** | ||
|
||
> 注:(默认开启flash_attn)使用flash_attn 要求A100/A800显卡或者H20显卡。V100请用float16推理。 | ||
|
||
## 3. 在线体验 | ||
|
||
<p align="center"> | ||
<video width="80%" height="auto" controls> | ||
<source src="https://github.com/user-attachments/assets/8e74c364-6d65-4930-b873-6fd5df263d9a" type="video/mp4"> | ||
您的浏览器不支持视频标签 | ||
</video> | ||
</p> | ||
|
||
|
||
我们提供了在线体验环境,您可以通过[AI Studio](https://aistudio.baidu.com/application/detail/60135)快速体验 PP-DocBee 的功能。 | ||
|
||
## 4. 使用指南 | ||
|
||
### 4.1 模型推理 | ||
|
||
下面展示了一个表格识别的示例: | ||
|
||
<p align="center"> | ||
<img src="https://github.com/user-attachments/assets/6a03a848-c396-4b2f-a7f3-47ff1441c750" width="50%" alt="示例图片"/> | ||
</p> | ||
|
||
```bash | ||
python paddlemix/examples/ppdocbee/single_image_infer.py \ | ||
--model_path "PaddleMIX/PPDocBee-2B" \ | ||
--image_file "your_image_path" \ | ||
--question "识别这份表格的内容" | ||
``` | ||
|
||
输出示例: | ||
``` | ||
名次 国家/地区 金牌 银牌 铜牌 奖牌总数 | ||
1 中国(CHN) 48 22 30 100 | ||
2 美国(USA) 36 39 37 112 | ||
3 俄罗斯(RUS) 24 13 23 60 | ||
4 英国(GBR) 19 13 19 51 | ||
5 德国(GER) 16 11 14 41 | ||
6 澳大利亚(AUS) 14 15 17 46 | ||
7 韩国(KOR) 13 11 8 32 | ||
8 日本(JPN) 9 8 8 25 | ||
9 意大利(ITA) 8 9 10 27 | ||
10 法国(FRA) 7 16 20 43 | ||
11 荷兰(NED) 7 5 4 16 | ||
12 乌克兰(UKR) 7 4 11 22 | ||
13 肯尼亚(KEN) 6 4 6 16 | ||
14 西班牙(ESP) 5 11 3 19 | ||
15 牙买加(JAM) 5 4 2 11 | ||
``` | ||
|
||
## 5. 性能评测 | ||
|
||
### 5.1 准确率评测 | ||
|
||
|
||
| Benchamrk | MiniCPM-V 2.0 | InternVL-2B | SmolVLM | Qwen2-VL-2B | **PP-DocBee** | | ||
| :-------------: | :--------: | :---------: | :---------: |:---------: | :-----------: | | ||
| Model Size | 2.43B | 2.21B | 2.25B | 2.21B | 2.21B | | ||
| DocVQA-val | 71.9(test)| 86.9(test) | 81.6(test)| 89.2 | **90.1** | | ||
| ChartQA-test | - | **76.2** | - | 73.5 | 74.6 | | ||
| InfoVQA-val | - | 58.9(test) | - | 64.1 | **65.4** | | ||
| TextVQA-val | 74.1 | 73.4 | 72.7 | 79.7 | **81.2** | | ||
| OCRBench | 605 | 781 | - | 794 | **828** | | ||
| ChineseOCRBench | - | - | - | 76.1 | **80.2** | | ||
| 内部中文评估集 | - | - | - | 52.8 | **60.3** | | ||
|
||
|
||
|
||
### 5.2 速度评测 | ||
|
||
|
||
|
||
## 引用 | ||
|
||
如果您在研究中使用了 PP-DocBee,请引用以下论文: | ||
|
||
```BibTeX | ||
@article{Qwen2-VL, | ||
title={Qwen2-VL}, | ||
author={Qwen team}, | ||
year={2024} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
|
||
import paddle | ||
from paddlenlp.transformers import Qwen2Tokenizer | ||
|
||
from paddlemix.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration | ||
from paddlemix.processors.qwen2_vl_processing import ( | ||
Qwen2VLImageProcessor, | ||
Qwen2VLProcessor, | ||
process_vision_info, | ||
) | ||
from paddlemix.utils.log import logger | ||
|
||
|
||
def main(args): | ||
paddle.seed(seed=0) | ||
compute_dtype = "float16" if args.fp16 else "bfloat16" | ||
if "npu" in paddle.get_device(): | ||
is_bfloat16_supported = True | ||
else: | ||
is_bfloat16_supported = paddle.amp.is_bfloat16_supported() | ||
if compute_dtype == "bfloat16" and not is_bfloat16_supported: | ||
logger.warning("bfloat16 is not supported on your device,change to float32") | ||
compute_dtype = "float32" | ||
|
||
model = Qwen2VLForConditionalGeneration.from_pretrained(args.model_path, dtype="bfloat16") | ||
|
||
image_processor = Qwen2VLImageProcessor() | ||
tokenizer = Qwen2Tokenizer.from_pretrained(args.model_path) | ||
processor = Qwen2VLProcessor(image_processor, tokenizer) | ||
|
||
# min_pixels = 256*28*28 # 200704 | ||
# max_pixels = 1280*28*28 # 1003520 | ||
# processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels) | ||
|
||
messages = [ | ||
{ | ||
"role": "user", | ||
"content": [ | ||
{ | ||
"type": "image", | ||
"image": f"{args.image_file}", | ||
}, | ||
{"type": "text", "text": f"{args.question}"}, | ||
], | ||
} | ||
] | ||
|
||
# Preparation for inference | ||
image_inputs, video_inputs = process_vision_info(messages) | ||
|
||
question = messages[0]["content"][1]["text"] | ||
image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>" | ||
text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{image_pad_token}{question}<|im_end|>\n<|im_start|>assistant\n" | ||
text = [text] | ||
|
||
inputs = processor( | ||
text=text, | ||
images=image_inputs, | ||
videos=video_inputs, | ||
padding=True, | ||
return_tensors="pd", | ||
) | ||
|
||
if args.benchmark: | ||
import time | ||
|
||
start = 0.0 | ||
total = 0.0 | ||
for i in range(20): | ||
if i > 10: | ||
start = time.time() | ||
with paddle.no_grad(): | ||
generated_ids = model.generate( | ||
**inputs, max_new_tokens=args.max_new_tokens, temperature=args.temperature | ||
) # already trimmed in paddle | ||
output_text = processor.batch_decode( | ||
generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
) | ||
if i > 10: | ||
total += time.time() - start | ||
print("s/it: ", total / 10) | ||
print(f"\nGPU memory usage: {paddle.device.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB") | ||
print("output_text:\n", output_text) | ||
|
||
else: | ||
# Inference: Generation of the output | ||
generated_ids = model.generate( | ||
**inputs, max_new_tokens=args.max_new_tokens, temperature=args.temperature | ||
) # already trimmed in paddle | ||
output_text = processor.batch_decode( | ||
generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False | ||
) | ||
print("output_text:\n", output_text) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_path", type=str, default="PaddleMIX/PPDocBee-2B") | ||
parser.add_argument("--question", type=str, default="What is written in the image?") | ||
parser.add_argument("--image_file", type=str, default="paddlemix/demo_images/ppdocbee_image1.jpg") | ||
parser.add_argument("--temperature", type=float, default=0.01) | ||
parser.add_argument("--max_new_tokens", type=int, default=128) | ||
parser.add_argument("--fp16", action="store_true") | ||
parser.add_argument("--benchmark", action="store_true") | ||
args = parser.parse_args() | ||
main(args) |