Commit

Support internvl2 grounding (modelscope#1473)

tastelikefeet authored Jul 24, 2024
1 parent 196746f commit 0c40a40
Showing 13 changed files with 392 additions and 230 deletions.
9 changes: 7 additions & 2 deletions docs/source/Multi-Modal/florence最佳实践.md
@@ -164,9 +164,14 @@ CUDA_VISIBLE_DEVICES=0 swift sft \
1. For tasks that ask about the target given a bounding box, specify `<bbox>` in the query, `<ref-object>` in the response, and provide the target and bounding box details in `objects`.
2. For tasks that ask for the bounding box of a given target, specify `<ref-object>` in the query, `<bbox>` in the response, and provide the target and bounding box details in `objects`.
```jsonl
{"query": "Find <bbox>", "response": "<ref-object>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[[\"bottom right sandwich\", [331, 266, 612, 530]]]" }
{"query": "Find <ref-object>", "response": "<bbox>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[[\"bottom right sandwich\", [331, 266, 612, 530]]]" }
{"query": "Find <bbox>", "response": "<ref-object>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]" }
{"query": "Find <ref-object>", "response": "<bbox>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]" }
```
The `objects` field above contains a JSON string with four fields:
1. `caption`: the description of the object that the bounding box refers to
2. `bbox`: the bounding box coordinates; four integers (rather than floats) are recommended, namely `x_min`, `y_min`, `x_max`, `y_max`
3. `bbox_type`: the bounding box type; three types are currently supported: `real`/`norm_1000`/`norm_1`, meaning actual pixel coordinates, thousandth-scale coordinates, and normalized (0-1) coordinates respectively
4. `image`: the index of the image the bounding box belongs to, starting from 0
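For example, one record in this format can be assembled with a short Python sketch like the following (the image path and annotation values are purely illustrative):

```python
import json

# Illustrative annotation; replace with your own image path and bounding box.
image_path = '/coco2014/train2014/COCO_train2014_000000001507.jpg'
objects = [{
    'caption': 'bottom right sandwich',  # description of the boxed object
    'bbox': [331, 266, 612, 530],        # x_min, y_min, x_max, y_max as integers
    'bbox_type': 'real',                 # actual pixel coordinates
    'image': 0,                          # index into the `images` list below
}]

record = {
    'query': 'Find <ref-object>',
    'response': '<bbox>',
    'images': [image_path],
    # Note that `objects` itself is stored as a JSON string inside the record.
    'objects': json.dumps(objects, ensure_ascii=False),
}

with open('grounding.jsonl', 'a', encoding='utf-8') as f:
    f.write(json.dumps(record, ensure_ascii=False) + '\n')
```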
## Inference after Fine-tuning
16 changes: 16 additions & 0 deletions docs/source/Multi-Modal/internvl最佳实践.md
@@ -362,6 +362,22 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \
{"query": "Describe this video in detail. Don't repeat", "response": "xxxxxxxxx", "history": [], "videos": ["video_path"]}
```

The **InternVL2** model also supports training on grounding tasks; the data should follow the format below:
```jsonl
{"query": "Find <bbox>", "response": "<ref-object>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]" }
{"query": "Find <ref-object>", "response": "<bbox>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]" }
```
The `objects` field above contains a JSON string with four fields:
1. `caption`: the description of the object that the bounding box refers to
2. `bbox`: the bounding box coordinates; four integers (rather than floats) are recommended, namely `x_min`, `y_min`, `x_max`, `y_max`
3. `bbox_type`: the bounding box type; three types are currently supported: `real`/`norm_1000`/`norm_1`, meaning actual pixel coordinates, thousandth-scale coordinates, and normalized (0-1) coordinates respectively
4. `image`: the index of the image the bounding box belongs to, starting from 0
This format will be converted into the format that InternVL2 recognizes, specifically:
```jsonl
{"query": "Find <ref>the man</ref>", "response": "<box> [[200, 200, 600, 600]] </box>"}
```
You can also provide data in this format directly, but note that the coordinates must then be thousandth-scale (`norm_1000`) coordinates.
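If you write the converted format yourself, thousandth-scale coordinates can be derived from real pixel coordinates with a small helper like the sketch below (the image size in the example is an assumption):

```python
from typing import List


def to_norm_1000(bbox: List[int], width: int, height: int) -> List[int]:
    """Convert real pixel coordinates [x_min, y_min, x_max, y_max]
    into thousandth-scale (norm_1000) coordinates."""
    x_min, y_min, x_max, y_max = bbox
    return [
        round(x_min / width * 1000),
        round(y_min / height * 1000),
        round(x_max / width * 1000),
        round(y_max / height * 1000),
    ]


# Assuming a 640x480 image for illustration:
print(to_norm_1000([138, 136, 235, 359], 640, 480))  # -> [216, 283, 367, 748]
```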

## Inference after Fine-tuning
Direct inference:
```shell
12 changes: 10 additions & 2 deletions docs/source_en/Multi-Modal/florence-best-pratice.md
@@ -163,9 +163,17 @@ Currently, two types of custom grounding tasks are supported:
1. For tasks asking about the target for a given bounding box, specify `<bbox>` in the query, `<ref-object>` in the response, and provide the target and bounding box details in objects.
2. For tasks asking about the bounding box for a given target, specify `<ref-object>` in the query, `<bbox>` in the response, and provide the target and bounding box details in objects.
```jsonl
{"query": "Find <bbox>", "response": "<ref-object>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[[\"bottom right sandwich\", [331, 266, 612, 530]]]" }
{"query": "Find <ref-object>", "response": "<bbox>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[[\"bottom right sandwich\", [331, 266, 612, 530]]]" }
{"query": "Find <bbox>", "response": "<ref-object>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]" }
{"query": "Find <ref-object>", "response": "<bbox>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]" }
```
The `objects` field contains a JSON string with four fields:
1. `caption`: Description of the object corresponding to the bounding box (bbox)
2. `bbox`: Coordinates of the bounding box. It is recommended to provide four integers (rather than float values), specifically `x_min`, `y_min`, `x_max`, and `y_max`.
3. `bbox_type`: Type of the bounding box. Currently, three types are supported: `real`, `norm_1000`, and `norm_1`, which respectively represent actual pixel value coordinates, thousandth ratio coordinates, and normalized ratio coordinates.
4. `image`: The index of the image corresponding to the bounding box. The index starts from 0.
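Before training, the `objects` field of each line can be sanity-checked with a small sketch like this (the checks mirror the list above; the file name is a placeholder):

```python
import json

REQUIRED_KEYS = {'caption', 'bbox', 'bbox_type', 'image'}
VALID_BBOX_TYPES = {'real', 'norm_1000', 'norm_1'}


def check_objects(objects_str: str) -> None:
    """Raise ValueError if an `objects` JSON string does not follow the format above."""
    for obj in json.loads(objects_str):
        missing = REQUIRED_KEYS - obj.keys()
        if missing:
            raise ValueError(f'missing keys: {missing}')
        if obj['bbox_type'] not in VALID_BBOX_TYPES:
            raise ValueError(f"unknown bbox_type: {obj['bbox_type']}")
        if len(obj['bbox']) != 4:
            raise ValueError(f"bbox must have 4 values, got {obj['bbox']}")


with open('grounding.jsonl', encoding='utf-8') as f:  # placeholder file name
    for line in f:
        check_objects(json.loads(line)['objects'])
```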
## Inference after Fine-tuning
Direct inference:
17 changes: 17 additions & 0 deletions docs/source_en/Multi-Modal/internvl-best-practice.md
@@ -327,6 +327,23 @@ The **InternVL2** model supports training with video datasets without the need t
{"query": "Describe this video in detail. Don't repeat", "response": "xxxxxxxxx", "history": [], "videos": ["video_path"]}
```

The **InternVL2** model also supports training on grounding tasks; the data should follow the format below:
```jsonl
{"query": "Find <bbox>", "response": "<ref-object>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]" }
{"query": "Find <ref-object>", "response": "<bbox>", "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]" }
```
The `objects` field contains a JSON string with four fields:
1. **caption**: Description of the object corresponding to the bounding box.
2. **bbox**: Coordinates of the bounding box. Four integers (rather than floats) are recommended, representing `x_min`, `y_min`, `x_max`, and `y_max`.
3. **bbox_type**: Type of bounding box. Currently, three types are supported: `real` / `norm_1000` / `norm_1`, representing actual pixel value coordinates / thousandth-scale coordinates / normalized coordinates.
4. **image**: The index of the corresponding image, starting from 0.

This format will be converted to a format recognizable by InternVL2, specifically:
```json
{"query": "Find <ref>the man</ref>", "response": "<box> [[200, 200, 600, 600]] </box>"}
```
You can also provide data in this format directly, but make sure the coordinates are thousandth-scale (`norm_1000`) coordinates.
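For intuition, a rough sketch of this conversion (not swift's actual implementation; the image size below is an assumption) could look like this:

```python
import json


def to_internvl2_style(record: dict, width: int, height: int) -> dict:
    """Rewrite <ref-object>/<bbox> placeholders into <ref>...</ref> and <box>...</box>
    with thousandth-scale coordinates. A simplified sketch, not swift's real code."""
    obj = json.loads(record['objects'])[0]
    x_min, y_min, x_max, y_max = obj['bbox']
    if obj['bbox_type'] == 'real':
        box = [round(x_min / width * 1000), round(y_min / height * 1000),
               round(x_max / width * 1000), round(y_max / height * 1000)]
    elif obj['bbox_type'] == 'norm_1':
        box = [round(v * 1000) for v in (x_min, y_min, x_max, y_max)]
    else:  # norm_1000 is already thousandth-scale
        box = [x_min, y_min, x_max, y_max]
    ref = f"<ref>{obj['caption']}</ref>"
    box_str = f'<box> [{box}] </box>'
    return {
        'query': record['query'].replace('<ref-object>', ref).replace('<bbox>', box_str),
        'response': record['response'].replace('<ref-object>', ref).replace('<bbox>', box_str),
    }


# Example with an assumed 640x480 image:
record = {'query': 'Find <ref-object>', 'response': '<bbox>',
          'objects': json.dumps([{'caption': 'guy in red', 'bbox': [138, 136, 235, 359],
                                  'bbox_type': 'real', 'image': 0}])}
print(to_internvl2_style(record, 640, 480))
# {'query': 'Find <ref>guy in red</ref>', 'response': '<box> [[216, 283, 367, 748]] </box>'}
```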

## Inference after Fine-tuning
Direct inference:
```shell
30 changes: 25 additions & 5 deletions swift/llm/export.py
@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Optional
from typing import List, Optional

import json
import torch
@@ -93,6 +93,25 @@ def gptq_model_quantize(model, tokenizer):
return gptq_quantizer


def replace_and_concat(template: 'Template', template_list: List, placeholder: str, keyword: str):
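# Flatten a template part list (plain strings, token-id lists, or special-token
# names such as 'bos_token_id') into a single string, replacing `placeholder`
# with `keyword` inside the plain-string parts.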
final_str = ''
for t in template_list:
if isinstance(t, str):
final_str += t.replace(placeholder, keyword)
elif isinstance(t, (tuple, list)):
if isinstance(t[0], int):
final_str += template.tokenizer.decode(t)
else:
for attr in t:
if attr == 'bos_token_id':
final_str += template.tokenizer.bos_token
elif attr == 'eos_token_id':
final_str += template.tokenizer.eos_token
else:
raise ValueError(f'Unknown token: {attr}')
return final_str
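A minimal usage sketch of `replace_and_concat` (the template parts and tokens below are made up; it only assumes the function above is in scope):

```python
# Stand-ins for the real Template/tokenizer objects, for illustration only.
class FakeTokenizer:
    bos_token = '<s>'
    eos_token = '</s>'

    def decode(self, ids):
        # Pretend the given token ids decode to an end-of-turn marker.
        return '<|im_end|>'


class FakeTemplate:
    tokenizer = FakeTokenizer()


prompt = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']
print(repr(replace_and_concat(FakeTemplate(), prompt, '{{QUERY}}', '{{ .Prompt }}')))
# -> '<|im_start|>user\n{{ .Prompt }}<|im_end|>\n<|im_start|>assistant\n'

ids_part = [[92542, 364]]  # a token-id part; our fake decode ignores the ids
print(replace_and_concat(FakeTemplate(), ids_part, '', ''))
# -> <|im_end|>

suffix = [['eos_token_id']]
print(replace_and_concat(FakeTemplate(), suffix, '', ''))
# -> </s>
```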


def llm_export(args: ExportArguments) -> None:
global _args, template
logger.info(f'args: {args}')
@@ -131,14 +150,15 @@ def llm_export(args: ExportArguments) -> None:
with open(os.path.join(args.ollama_output_dir, 'Modelfile'), 'w') as f:
f.write(f'FROM {model_dir}\n')
f.write(f'TEMPLATE """{{{{ if .System }}}}'
f'{template.system_prefix[0].replace("{{SYSTEM}}", "{{ .System }}")}'
f'{replace_and_concat(template, template.system_prefix, "{{SYSTEM}}", "{{ .System }}")}'
f'{{{{ else }}}}{replace_and_concat(template, template.prefix, "", "")}'
f'{{{{ end }}}}')
f.write(f'{{{{ if .Prompt }}}}'
f'{template.prompt[0].replace("{{QUERY}}", "{{ .Prompt }}")}'
f'{replace_and_concat(template, template.prompt, "{{QUERY}}", "{{ .Prompt }}")}'
f'{{{{ end }}}}')
f.write('{{ .Response }}')
f.write(template.suffix[0] + '"""\n')
f.write(f'PARAMETER stop "{template.suffix[0]}"\n')
f.write(replace_and_concat(template, template.suffix, '', '') + '"""\n')
f.write(f'PARAMETER stop "{replace_and_concat(template, template.suffix, "", "")}"\n')
if args.stop_words:
for stop_word in args.stop_words:
f.write(f'PARAMETER stop "{stop_word}"\n')
16 changes: 13 additions & 3 deletions swift/llm/utils/dataset.py
@@ -1203,7 +1203,12 @@ def preprocess(row):
bbox[i] = round(float(bbox[i]))
res = {}

objects = [[caption, bbox]]
objects = [{
'caption': caption,
'bbox': bbox,
'bbox_type': 'real',
'image': 0,
}]
media_tag(res, [image_path])
res['images'] = [image_path]
res['objects'] = json.dumps(objects, ensure_ascii=False)
@@ -1248,7 +1253,12 @@ def preprocess(row):
bbox[i] = round(float(bbox[i]))
res = {}

objects = [[caption, bbox]]
objects = [{
'caption': caption,
'bbox': bbox,
'bbox_type': 'real',
'image': 0,
}]
media_tag(res, [image_path])
res['images'] = [image_path]
res['objects'] = json.dumps(objects, ensure_ascii=False)
@@ -1683,7 +1693,7 @@ def preprocess_row(row):
start_end_pairs.append(ref_exp[0:2])

object_part = caption[int(start):int(end)]
objects.append([object_part, ref_exp[2:6]])
objects.append({'caption': object_part, 'bbox': ref_exp[2:6], 'bbox_type': 'real', 'image': 0})

start_end_pairs.sort(key=lambda x: (x[0], x[1]))
if has_overlap(start_end_pairs):
2 changes: 2 additions & 0 deletions swift/llm/utils/media.py
@@ -18,6 +18,8 @@ class MediaTag:
'en': [('<ref-object>', '<bbox>'), ('The positions of <ref-object> is', '<bbox>'),
('Find the positions of <ref-object>', '<bbox>'), ('Where is <ref-object>', '<bbox>'),
('Find <ref-object>', '<bbox>'), ('Show me <ref-object>', '<bbox>'),
('Detect <ref-object>', '<bbox>'), ('Locate <ref-object>', '<bbox>'),
('Tell me the location of <ref-object>', '<bbox>'), ('Give the location of <ref-object>', '<bbox>'),
('Provide the bounding box coordinate of <ref-object>', '<bbox>')],
'zh': [('<ref-object>', '<bbox>'), ('<ref-object>的位置在图片中', '<bbox>'), ('<ref-object>在图片中', '<bbox>'),
('<ref-object>在', '<bbox>'), ('找到<ref-object>的位置', '<bbox>'), ('<ref-object>在哪里', '<bbox>'),
19 changes: 11 additions & 8 deletions swift/llm/utils/model.py
@@ -525,6 +525,9 @@ class LoRATM(NamedTuple):
'o_proj',
]
minicpm_llama = r'.*model\.layers\.(?:[0-9]|[12][0-9]|3[01])\.(?:self_attn\.(?:q_proj|k_proj|v_proj))'
internvl2 = r'.*(wqkv|wo|w[123]|mlp1\.(1|3))$'
internvl2_llama = r'.*(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj|mlp1\.(1|3))$'
internvl2_phi3 = r'.*(qkv_proj|o_proj|gate_up_proj|down_proj|mlp1\.(1|3))$'
# compat
llama2 = llama
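For a sense of which modules the new patterns select, here is a quick check against illustrative module names (not taken from an actual InternVL2 checkpoint):

```python
import re

# The same patterns as above; the module names being tested are illustrative.
internvl2 = r'.*(wqkv|wo|w[123]|mlp1\.(1|3))$'
internvl2_llama = r'.*(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj|mlp1\.(1|3))$'

print(bool(re.match(internvl2, 'language_model.model.layers.0.attention.wqkv')))         # True
print(bool(re.match(internvl2, 'mlp1.1')))                                                # True
print(bool(re.match(internvl2_llama, 'language_model.model.layers.3.self_attn.q_proj')))  # True
print(bool(re.match(internvl2, 'vision_model.encoder.layers.0.attn.qkv')))                # False
```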

def register_model(
model_type: str,
model_id_or_path: Optional[str],
lora_target_modules: Optional[List[str]] = None,
lora_target_modules: Optional[Union[List[str], str]] = None,
template: str = TemplateType.default,
get_function: Optional[GetModelTokenizerFunction] = None,
*,
@@ -3787,7 +3790,7 @@ def patch_internvl_forward(model) -> None:
@register_model(
ModelType.internvl2_1b,
'OpenGVLab/InternVL2-1B',
LoRATM.llama,
LoRATM.internvl2_llama,
TemplateType.internvl2,
requires=['transformers>=4.35', 'timm'],
support_flash_attn=True,
@register_model(
ModelType.internvl2_2b,
'OpenGVLab/InternVL2-2B',
LoRATM.internlm2,
LoRATM.internvl2,
TemplateType.internvl2,
requires=['transformers>=4.35', 'timm'],
support_flash_attn=True,
@register_model(
ModelType.internvl2_4b,
'OpenGVLab/InternVL2-4B',
LoRATM.phi3,
LoRATM.internvl2_phi3,
TemplateType.internvl2_phi3,
requires=['transformers>=4.35', 'timm'],
support_flash_attn=True,
@register_model(
ModelType.internvl2_8b,
'OpenGVLab/InternVL2-8B',
LoRATM.internlm2,
LoRATM.internvl2,
TemplateType.internvl2,
requires=['transformers>=4.35', 'timm'],
support_flash_attn=True,
@register_model(
ModelType.internvl2_26b,
'OpenGVLab/InternVL2-26B',
LoRATM.internlm2,
LoRATM.internvl2,
TemplateType.internvl2,
requires=['transformers>=4.35', 'timm'],
support_flash_attn=True,
@register_model(
ModelType.internvl2_40b,
'OpenGVLab/InternVL2-40B',
LoRATM.llama,
LoRATM.internvl2_llama,
TemplateType.internvl2,
requires=['transformers>=4.35', 'timm'],
support_flash_attn=True,
@register_model(
ModelType.internvl2_llama3_76b,
'OpenGVLab/InternVL2-Llama3-76B',
LoRATM.llama,
LoRATM.internvl2_llama,
TemplateType.internvl2,
requires=['transformers>=4.35', 'timm'],
support_flash_attn=True,