add examples: multimodal tool use with qwen2-vl
gewenbin0992 authored and JianxinMa committed Aug 9, 2024
1 parent 3c4f8d0 commit 8f24dbf
Showing 9 changed files with 225 additions and 175 deletions.
57 changes: 0 additions & 57 deletions examples/assistant_angry_girlfriend.py

This file was deleted.

93 changes: 0 additions & 93 deletions examples/assistant_growing_girl.py

This file was deleted.

137 changes: 137 additions & 0 deletions examples/qwen2vl_assistant_tooluse.py
@@ -0,0 +1,137 @@
import os
import re
import uuid
from io import BytesIO
from pprint import pprint
from typing import List, Optional, Union

import requests
from PIL import Image

from qwen_agent.agents import FnCallAgent
from qwen_agent.llm.schema import ContentItem
from qwen_agent.tools.base import BaseToolWithFileAccess, register_tool

ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), 'resource')


@register_tool('crop_and_resize')
class CropResize(BaseToolWithFileAccess):
    description = 'A magnifier tool: it crops a region of an image and enlarges it to reveal finer detail. Call it when you cannot make out the details directly.'
parameters = [
{
'name': 'image',
'type': 'string',
            'description': 'Local path or URL of the input image',
'required': True
},
{
'name': 'rectangle',
'type': 'string',
            'description': 'The region to crop, given by its top-left and bottom-right corners (origin at the top-left of the image, x-axis pointing right, y-axis pointing down), in the format: (x1,y1),(x2,y2)',
'required': True
},
]

def _extract_coordinates(self, text):
        # First try the documented two-point format: (x1,y1),(x2,y2)
        pattern = r'\((\d+),\s*(\d+)\)'
matches = re.findall(pattern, text)
coordinates = [(int(x), int(y)) for x, y in matches]
if len(coordinates) >= 2:
x1, y1 = coordinates[0]
x2, y2 = coordinates[1]
return x1, y1, x2, y2

        # Fall back to a flat 4-tuple format: (x1,y1,x2,y2)
        pattern = r'\((\d+),\s*(\d+),\s*(\d+),\s*(\d+)\)'
matches = re.findall(pattern, text)
coordinates = [(int(x1), int(y1), int(x2), int(y2)) for x1, y1, x2, y2 in matches]
        return coordinates[0]

def _expand_box(self, x1, y1, x2, y2, factor=1):
xc = (x1 + x2) / 2
yc = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
w_new = w * factor
h_new = h * factor
return xc - w_new / 2, yc - h_new / 2, xc + w_new / 2, yc + h_new / 2

    def call(self, params: Union[str, dict], files: Optional[List[str]] = None, **kwargs) -> List[ContentItem]:
super().call(params=params, files=files)
params = self._verify_json_format_args(params)

image_arg = params['image'] # local path or url
rectangle = params['rectangle']

# open image
if image_arg.startswith('http'):
response = requests.get(image_arg)
response.raise_for_status()
image = Image.open(BytesIO(response.content))
elif os.path.exists(image_arg):
image = Image.open(image_arg)
else:
image = Image.open(os.path.join(self.work_dir, image_arg))

coordinates = self._extract_coordinates(rectangle)
x1, y1, x2, y2 = self._expand_box(*coordinates, factor=1.35)

w, h = image.size
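        # The rectangle uses a 0-1000 normalized grid (Qwen2-VL's grounding convention); rescale it to pixel coordinates.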
x1, y1 = round(x1 / 1000 * w), round(y1 / 1000 * h)
x2, y2 = round(x2 / 1000 * w), round(y2 / 1000 * h)

        # clamp the box to the image bounds
x1, y1, x2, y2 = max(x1, 0), max(y1, 0), min(x2, w), min(y2, h)

cropped_image = image.crop((x1, y1, x2, y2))

# save
output_path = os.path.abspath(os.path.join(self.work_dir, f'{uuid.uuid4()}.png'))
cropped_image.save(output_path)

return [
ContentItem(image=output_path),
            ContentItem(text=f'( The URL of this enlarged crop is {output_path} )'),
]


def test():
llm_cfg_vl = {
# Using Qwen2-VL deployed at any openai-compatible service such as vLLM:
# 'model_type': 'qwenvl_oai',
# 'model': 'Qwen/Qwen2-VL-72B-Instruct',
# 'model_server': 'http://localhost:8000/v1', # api_base
# 'api_key': 'EMPTY',

# Using Qwen2-VL provided by Alibaba Cloud DashScope:
# 'model_type': 'qwenvl_dashscope',
# 'model': 'qwen2-vl-72b-instruct',
# 'api_key': os.getenv('DASHSCOPE_API_KEY'),

# TODO: Use qwen2-vl instead once qwen2-vl is released.
'model_type': 'qwenvl_dashscope',
'model': 'qwen-vl-max',
'api_key': os.getenv('DASHSCOPE_API_KEY'),
'generate_cfg': dict(max_retries=10,)
}

agent = FnCallAgent(function_list=['crop_and_resize'], llm=llm_cfg_vl)
    messages = [{
        'role': 'user',
        'content': [
            {
                'image': os.path.abspath(os.path.join(ROOT_RESOURCE, 'screenshot_with_plot.jpeg'))
            },
            {
                'text': 'Use the tool to zoom in on the table on the right.'
            },
        ],
    }]
response = agent.run_nonstream(messages=messages)
pprint(response, indent=4)


if __name__ == '__main__':
test()
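
For quick local testing, the crop tool can also be exercised directly, without going through the agent. A minimal sketch, assuming it is appended to the example file above so that CropResize, ROOT_RESOURCE, and os are already in scope; smoke_test_crop_tool is a hypothetical helper, and the rectangle selects the right half of the image on the tool's 0-1000 normalized grid:

def smoke_test_crop_tool():
    # Direct tool invocation, bypassing the LLM entirely.
    tool = CropResize()
    items = tool.call(params={
        'image': os.path.abspath(os.path.join(ROOT_RESOURCE, 'screenshot_with_plot.jpeg')),
        'rectangle': '(500,0),(1000,1000)',  # right half on the 0-1000 grid
    })
    for item in items:
        # Each ContentItem carries either the cropped image path or a text note.
        print(item.image or item.text)
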
81 changes: 81 additions & 0 deletions examples/qwen2vl_function_calling.py
@@ -0,0 +1,81 @@
import json
import urllib.parse

from qwen_agent.llm import get_chat_model
from qwen_agent.llm.schema import ContentItem


def image_gen(prompt: str) -> str:
prompt = urllib.parse.quote(prompt)
image_url = f'https://image.pollinations.ai/prompt/{prompt}'
return image_url


def test():
# Config for the model
llm_cfg_oai = {
# Using Qwen2-VL deployed at any openai-compatible service such as vLLM:
'model_type': 'qwenvl_oai',
'model': 'Qwen/Qwen2-VL-72B-Instruct',
'model_server': 'http://localhost:8000/v1', # api_base
'api_key': 'EMPTY',
}
llm = get_chat_model(llm_cfg_oai)

# Initial conversation
    messages = [{
        'role': 'user',
        'content': [{
            'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg'
        }, {
            'text': 'What is in the picture? Please draw a new picture with the same content and a similar style.'
        }]
    }]

functions = [
{
'name': 'image_gen',
            'description': 'AI painting (image generation) service: takes a text description and returns the URL of an image drawn from that description.',
'parameters': {
'name': 'prompt',
'type': 'string',
                'description': 'A detailed description of the desired image content, such as subjects, setting, and actions; written in English',
'required': True
}
},
]

print('# Assistant Response 1:')
responses = []
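    # llm.chat with stream=True yields the accumulated responses at each step, so after the loop, responses holds the complete reply.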
for responses in llm.chat(messages=messages, functions=functions, stream=True):
print(responses)
messages.extend(responses)

for rsp in responses:
if rsp.get('function_call', None):
func_name = rsp['function_call']['name']
if func_name == 'image_gen':
func_args = json.loads(rsp['function_call']['arguments'])
image_url = image_gen(func_args['prompt'])
print('# Function Response:')
func_rsp = {
'role': 'function',
'name': func_name,
'content': [ContentItem(image=image_url),
                            ContentItem(text=f'( The URL of this image is {image_url} )')],
}
messages.append(func_rsp)
print(func_rsp)
else:
raise NotImplementedError

print('# Assistant Response 2:')
responses = []
for responses in llm.chat(messages=messages, functions=functions, stream=True):
print(responses)
messages.extend(responses)


if __name__ == '__main__':
test()
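
The inline dispatch in test() generalizes naturally once more than one function is registered. A minimal sketch of a table-driven variant, assuming the same message and function_call shapes shown above; TOOLS and run_function_calls are hypothetical names, and image_gen remains the only registered tool:

TOOLS = {'image_gen': image_gen}

def run_function_calls(responses):
    # Convert any function_call messages into 'function'-role replies,
    # mirroring the inline handling in test().
    func_msgs = []
    for rsp in responses:
        call = rsp.get('function_call', None)
        if not call:
            continue
        if call['name'] not in TOOLS:
            raise NotImplementedError(call['name'])
        image_url = TOOLS[call['name']](**json.loads(call['arguments']))
        func_msgs.append({
            'role': 'function',
            'name': call['name'],
            'content': [ContentItem(image=image_url),
                        ContentItem(text=f'( The URL of this image is {image_url} )')],
        })
    return func_msgs
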
Binary file removed examples/resource/blood_routine.pdf
Binary file removed examples/resource/growing_girl.pdf
Binary file added examples/resource/screenshot_with_plot.jpeg
