add fn_call example and bugfix gen_keyword
JianxinMa committed Feb 7, 2024
1 parent c722524 commit 6d13dcb
Showing 22 changed files with 253 additions and 113 deletions.
4 changes: 2 additions & 2 deletions examples/assistant_data_analysis.py
@@ -17,10 +17,10 @@ def app():
     # chat
     messages = []
     while True:
-        # query example: Help me draw a line chart to show the changes in stock prices
+        # query example: pd.head the file first and then help me draw a line chart to show the changes in stock prices
         query = input('user question: ')
         # file example: resource/stock_prices.csv
-        file = input('file url (press enter if no file): ')
+        file = input('file url (press enter if no file): ').strip()
         if not query:
             print('user question cannot be empty!')
             continue
2 changes: 1 addition & 1 deletion examples/assistant_doctor.py
@@ -31,7 +31,7 @@ def app():
         # query example: 医生,可以帮我看看我是否健康吗?
         query = input('user question: ')
         # file example: https://pic4.zhimg.com/80/v2-2c8eedf3e12386fedcd5589cf5575717_720w.webp
-        file = input('file url (press enter if no file): ')
+        file = input('file url (press enter if no file): ').strip()
         if not query:
             print('user question cannot be empty!')
             continue
2 changes: 1 addition & 1 deletion examples/assistant_growing_girl.py
@@ -23,7 +23,7 @@ def app():
         # query example: 请开始创作!
         query = input('user question: ')
         # file example: resource/growing_girl.pdf
-        file = input('file url (press enter if no file): ')
+        file = input('file url (press enter if no file): ').strip()
         if not query:
             print('user question cannot be empty!')
             continue
2 changes: 1 addition & 1 deletion examples/assistant_weather_bot.py
@@ -24,7 +24,7 @@ def app():
         # query example: 海淀区天气
         query = input('user question: ')
         # file example: resource/poem.pdf
-        file = input('file url (press enter if no file): ')
+        file = input('file url (press enter if no file): ').strip()
        if not query:
             print('user question cannot be empty!')
             continue
114 changes: 114 additions & 0 deletions examples/function_calling.py
@@ -0,0 +1,114 @@
# Reference: https://platform.openai.com/docs/guides/function-calling
import json

from qwen_agent.llm import get_chat_model


# Example dummy function hard coded to return the same weather
# In production, this could be your backend API or an external API
def get_current_weather(location, unit='fahrenheit'):
    """Get the current weather in a given location"""
    if 'tokyo' in location.lower():
        return json.dumps({
            'location': 'Tokyo',
            'temperature': '10',
            'unit': 'celsius'
        })
    elif 'san francisco' in location.lower():
        return json.dumps({
            'location': 'San Francisco',
            'temperature': '72',
            'unit': 'fahrenheit'
        })
    elif 'paris' in location.lower():
        return json.dumps({
            'location': 'Paris',
            'temperature': '22',
            'unit': 'celsius'
        })
    else:
        return json.dumps({'location': location, 'temperature': 'unknown'})


def run_conversation():
    llm = get_chat_model({
        # Use the model service provided by DashScope:
        'model': 'qwen-max',
        'model_server': 'dashscope',
        # 'api_key': 'YOUR_DASHSCOPE_API_KEY',
        # It will use the `DASHSCOPE_API_KEY` environment variable if 'api_key' is not set.

        # Use your own model service compatible with OpenAI API:
        # 'model': 'Qwen/Qwen1.5-72B-Chat',
        # 'model_server': 'http://localhost:8000/v1',  # api_base
        # 'api_key': 'EMPTY',
    })

    # Step 1: send the conversation and available functions to the model
    messages = [{
        'role': 'user',
        'content': "What's the weather like in San Francisco?"
    }]
    functions = [{
        'name': 'get_current_weather',
        'description': 'Get the current weather in a given location',
        'parameters': {
            'type': 'object',
            'properties': {
                'location': {
                    'type': 'string',
                    'description':
                    'The city and state, e.g. San Francisco, CA',
                },
                'unit': {
                    'type': 'string',
                    'enum': ['celsius', 'fahrenheit']
                },
            },
            'required': ['location'],
        },
    }]

    print('# Assistant Response 1:')
    responses = llm.chat(messages=messages, functions=functions, stream=False)
    print(responses)

    messages.extend(responses)  # extend conversation with assistant's reply

    # Step 2: check if the model wanted to call a function
    last_response = messages[-1]
    if last_response.get('function_call', None):

        # Step 3: call the function
        # Note: the JSON response may not always be valid; be sure to handle errors
        available_functions = {
            'get_current_weather': get_current_weather,
        }  # only one function in this example, but you can have multiple
        function_name = last_response['function_call']['name']
        function_to_call = available_functions[function_name]
        function_args = json.loads(last_response['function_call']['arguments'])
        function_response = function_to_call(
            location=function_args.get('location'),
            unit=function_args.get('unit'),
        )
        print('# Function Response:')
        print(function_response)

        # Step 4: send the info for each function call and function response to the model
        messages.append({
            'role': 'function',
            'name': function_name,
            'content': function_response,
        })  # extend conversation with function response

        print('# Assistant Response 2:')
        responses = llm.chat(
            messages=messages,
            functions=functions,
            stream=False,
        )  # get a new response from the model where it can see the function response
        print(responses)


if __name__ == '__main__':
    run_conversation()
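For reference, when the model decides to use the tool, the first `llm.chat(...)` call above returns a list of messages whose last entry carries a `function_call` field. The following is a hypothetical illustration of that shape, not actual model output; the exact `arguments` text varies by model, but the keys match what Steps 2 and 3 read back out:

# Hypothetical `responses` value after Step 1 (illustrative only);
# 'arguments' arrives as a JSON string, which Step 3 parses via json.loads().
[{
    'role': 'assistant',
    'content': '',
    'function_call': {
        'name': 'get_current_weather',
        'arguments': '{"location": "San Francisco, CA", "unit": "fahrenheit"}',
    },
}]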
2 changes: 1 addition & 1 deletion examples/multi_agent_router.py
@@ -40,7 +40,7 @@ def app():
         # image example: https://img01.sc115.com/uploads/sc/jpgs/1505/apic11540_sc115.com.jpg
         image = input('image url (press enter if no image): ')
         # file example: resource/poem.pdf
-        file = input('file url (press enter if no file): ')
+        file = input('file url (press enter if no file): ').strip()
         if not query:
             print('user question cannot be empty!')
             continue
2 changes: 1 addition & 1 deletion examples/visual_storytelling.py
@@ -66,7 +66,7 @@ def app():
     while True:
         query = input('user question: ')
         # image example: https://img01.sc115.com/uploads3/sc/vector/201809/51413-20180914205509.jpg
-        image = input('image url: ')
+        image = input('image url: ').strip()
 
         if not image:
             print('image cannot be empty!')
3 changes: 1 addition & 2 deletions qwen_agent/agent.py
@@ -32,8 +32,7 @@ def __init__(self,
         :param kwargs: other potential parameters
         """
         if isinstance(llm, dict):
-            self.llm_config = llm
-            self.llm = get_chat_model(self.llm_config)
+            self.llm = get_chat_model(llm)
         else:
             self.llm = llm
 
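The change above drops the unused `self.llm_config` attribute; a config dict is now passed straight to `get_chat_model`, and both construction paths behave the same. A minimal sketch, assuming the `Assistant` agent defined elsewhere in this repo and the same config dict as examples/function_calling.py:

from qwen_agent.agents import Assistant
from qwen_agent.llm import get_chat_model

# Path 1: pass a config dict; Agent.__init__ resolves it via get_chat_model(llm).
bot = Assistant(llm={'model': 'qwen-max', 'model_server': 'dashscope'})

# Path 2: pass a ready-made BaseChatModel instance; it is stored on self.llm as-is.
llm = get_chat_model({'model': 'qwen-max', 'model_server': 'dashscope'})
bot = Assistant(llm=llm)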
27 changes: 15 additions & 12 deletions qwen_agent/agents/assistant.py
@@ -1,8 +1,10 @@
+import copy
 from typing import Dict, Iterator, List, Optional, Union
 
 from qwen_agent import Agent
 from qwen_agent.llm import BaseChatModel
-from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, SYSTEM
+from qwen_agent.llm.schema import (CONTENT, DEFAULT_SYSTEM_MESSAGE, FUNCTION,
+                                   ROLE, SYSTEM, Message)
 from qwen_agent.log import logger
 from qwen_agent.memory import Memory
 from qwen_agent.utils.utils import format_knowledge_to_source_and_content
@@ -48,11 +50,12 @@ def __init__(self,
         self.mem = Memory(llm=self.llm, files=files)
 
     def _run(self,
-             messages: List[Dict],
+             messages: List[Message],
              lang: str = 'zh',
              max_ref_token: int = 4000,
              **kwargs) -> Iterator[List[Dict]]:
-        system_prompt = ''
+        messages = copy.deepcopy(messages)
+        knowledge_prompt = ''
 
         # retrieval knowledge from files
         *_, last = self.mem.run(messages=messages, max_ref_token=max_ref_token)
@@ -64,13 +67,13 @@ def _run(self,
         for k in knowledge:
             snippets.append(KNOWLEDGE_SNIPPET[lang].format(
                 source=k['source'], content=k['content']))
-        system_prompt += KNOWLEDGE_TEMPLATE[lang].format(
+        knowledge_prompt += KNOWLEDGE_TEMPLATE[lang].format(
             knowledge='\n\n'.join(snippets))
 
         if messages[0][ROLE] == SYSTEM:
-            messages[0][CONTENT] += system_prompt
+            messages[0][CONTENT] += knowledge_prompt
         else:
-            messages.insert(0, {ROLE: SYSTEM, CONTENT: system_prompt})
+            messages.insert(0, Message(role=SYSTEM, content=knowledge_prompt))
 
         max_turn = 5
         response = []
@@ -81,19 +84,19 @@ def _run(self,
                 functions=[
                     func.function for func in self.function_map.values()
                 ])
-            output = []
+            output: List[Message] = []
             for output in output_stream:
                 yield response + output
             response.extend(output)
             messages.extend(output)
             use_tool, action, action_input, _ = self._detect_tool(response[-1])
             if use_tool:
                 observation = self._call_tool(action, action_input)
-                fn_msg = {
-                    'role': 'function',
-                    'name': action,
-                    'content': observation,
-                }
+                fn_msg = Message(
+                    role=FUNCTION,
+                    name=action,
+                    content=observation,
+                )
                 messages.append(fn_msg)
                 response.append(fn_msg)
                 yield response
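The substantive change in this file is the migration from raw dicts to the typed `Message` model, plus a `copy.deepcopy` so the caller's message list is never mutated. A minimal sketch of the two message kinds `_run` now builds, assuming the `Message` schema imported above (tool name and content are illustrative placeholders):

from qwen_agent.llm.schema import FUNCTION, SYSTEM, Message

# Knowledge-bearing system message, inserted when the conversation has none:
sys_msg = Message(role=SYSTEM, content='<knowledge_prompt text>')

# Tool-observation message, appended after each tool call in the loop:
fn_msg = Message(role=FUNCTION,
                 name='code_interpreter',  # illustrative tool name
                 content='<observation text>')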
35 changes: 18 additions & 17 deletions qwen_agent/llm/base.py
@@ -83,9 +83,9 @@ def chat(
                 stream=stream,
                 delta_stream=delta_stream)
             if isinstance(output, list):
-                output = self._postprocess_messages_for_func_call(output)
+                output = self._postprocess_messages_for_fn_call(output)
             else:
-                output = self._postprocess_messages_iterator_for_func_call(
+                output = self._postprocess_messages_iterator_for_fn_call(
                     output)
         else:
             messages = self._preprocess_messages(messages)
@@ -108,7 +108,7 @@ def _chat_with_functions(
         delta_stream: bool = False
     ) -> Union[List[Message], Iterator[List[Message]]]:
 
-        messages = self._prepend_tool_message(messages, functions)
+        messages = self._prepend_fn_call_system(messages, functions)
         messages = self._preprocess_messages(messages)
 
         if messages and messages[-1][ROLE] == ASSISTANT:
@@ -147,8 +147,9 @@ def _chat_no_stream(
     ) -> List[Message]:
         raise NotImplementedError
 
-    def _prepend_tool_message(
-            self, messages: List[Message],
+    @staticmethod
+    def _prepend_fn_call_system(
+            messages: List[Message],
             functions: Optional[List[Dict]]) -> List[Message]:
         # prepend tool react prompt
         tool_desc_template = FN_CALL_TEMPLATE['en']
@@ -173,19 +174,19 @@ def _prepend_tool_message(
     @staticmethod
     def _format_as_multimodal_messages(
            messages: List[Message]) -> List[Message]:
-        new_messages = []
+        multimodal_messages = []
         for msg in messages:
             role = msg[ROLE]
             assert role in (USER, ASSISTANT, SYSTEM, FUNCTION)
             if role == FUNCTION:
-                new_messages.append(msg)
+                multimodal_messages.append(msg)
                 continue
 
             content = []
-            if isinstance(msg[CONTENT], str):
+            if isinstance(msg[CONTENT], str):  # if text content
                 if msg[CONTENT]:
                     content = [ContentItem(text=msg[CONTENT])]
-            elif isinstance(msg[CONTENT], list):
+            elif isinstance(msg[CONTENT], list):  # if multimodal content
                 files = []
                 for item in msg[CONTENT]:
                     for k, v in item.model_dump().items():
@@ -195,7 +196,7 @@ def _format_as_multimodal_messages(
                         content.append(item)
                         if k in ('file', 'image'):
                             files.append(v)
-                if files:
+                if (msg[ROLE] in (SYSTEM, USER)) and files:
                     has_zh = has_chinese_chars(content)
                     upload = []
                     for f in [get_basename_from_url(f) for f in files]:
@@ -211,20 +212,20 @@ def _format_as_multimodal_messages(
                             upload.append(f'[file]({f})')
                     upload = ' '.join(upload)
                     if has_zh:
-                        upload = f'(上传了 {upload})'
+                        upload = f'(上传了 {upload})\n\n'
                     else:
-                        upload = f'(Uploaded {upload})'
+                        upload = f'(Uploaded {upload})\n\n'
                     content = [ContentItem(text=upload)
                                ] + content  # insert a text
             else:
                 raise TypeError
 
-            new_messages.append(
+            multimodal_messages.append(
                 Message(role=role,
                         content=content,
                         function_call=msg.function_call))
 
-        return new_messages
+        return multimodal_messages
 
     def _preprocess_messages(self, messages: List[Message]) -> List[Message]:
         """
@@ -274,7 +275,7 @@ def _preprocess_messages(self, messages: List[Message]) -> List[Message]:
                 break
         return new_messages
 
-    def _postprocess_messages_for_func_call(
+    def _postprocess_messages_for_fn_call(
             self, messages: List[Message]) -> List[Message]:
         """
         If the model calls function by built-in function call template, convert and display it in function_call format in return.
@@ -371,11 +372,11 @@ def _postprocess_messages_for_func_call(
                         new_content))  # no func call
         return new_messages
 
-    def _postprocess_messages_iterator_for_func_call(
+    def _postprocess_messages_iterator_for_fn_call(
             self,
             messages: Iterator[List[Message]]) -> Iterator[List[Message]]:
         for m in messages:
-            yield self._postprocess_messages_for_func_call(m)
+            yield self._postprocess_messages_for_fn_call(m)
 
     def _convert_messages_to_target_type(
             self, messages: List[Message],
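Beyond the renames, the behavioral change in `_format_as_multimodal_messages` is that the "(Uploaded ...)" note is now prepended only for system and user messages, and is followed by a blank line. A rough sketch of input and output, assuming the `ContentItem` model from qwen_agent.llm.schema (exact rendering may differ):

from qwen_agent.llm.schema import USER, ContentItem, Message

msg = Message(role=USER,
              content=[ContentItem(file='resource/stock_prices.csv'),
                       ContentItem(text='Draw a line chart of the prices.')])
# After formatting, a text item is prepended in front of the original content,
# roughly: ContentItem(text='(Uploaded [file](stock_prices.csv))\n\n')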