Commit
Support: handle tool calls in stream mode + support streamed output when calling tools
Samge0 committed Dec 5, 2023
1 parent 0b2b342 commit 95942a7
Showing 2 changed files with 148 additions and 4 deletions.
150 changes: 147 additions & 3 deletions openai_api_demo/openai_api.py
@@ -8,6 +8,7 @@

import os
import time
import json
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union

@@ -20,6 +21,7 @@
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel

from tool_using.tool_register import dispatch_tool
from utils import process_response, generate_chatglm3, generate_stream_chatglm3

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
@@ -144,9 +146,53 @@ async def create_chat_completion(request: ChatCompletionRequest):
logger.debug(f"==== request ====\n{gen_params}")

if request.stream:
generate = predict(request.model, gen_params)
return EventSourceResponse(generate, media_type="text/event-stream")

# Use stream mode to read the first few characters; if it is not a function call, stream the output directly
predict_stream_generator = predict_stream(request.model, gen_params)
output = next(predict_stream_generator)
if not contains_custom_function(output):
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")

# Obtain the full result at once and determine whether a tool needs to be called.
logger.debug(f"First result output:\n{output}")

function_call = None
if output and request.functions:
try:
function_call = process_response(output, use_tool=True)
except:
logger.warning("Failed to parse tool call")

# Call the function
if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
function_args = json.loads(function_call.arguments)
# Dispatch the tool call; this request may be time-consuming
tool_response = dispatch_tool(function_call.name, function_args)

if not gen_params.get("messages"):
gen_params["messages"] = []

gen_params["messages"].append(ChatMessage(
role="assistant",
content=output,
))
gen_params["messages"].append(ChatMessage(
role="function",
name=function_call.name,
content=tool_response,
))

# Streaming output of results after function calls
generate = predict(request.model, gen_params)
return EventSourceResponse(generate, media_type="text/event-stream")

else:
# Fallback to avoid exceptions from the parsing step above.
generate = parse_output_text(request.model, output)
return EventSourceResponse(generate, media_type="text/event-stream")

# Here is the handling of stream = False
response = generate_chatglm3(model, tokenizer, gen_params)

# Remove the first newline character
@@ -210,7 +256,8 @@ async def predict(model_id: str, params: dict):
try:
function_call = process_response(decoded_unicode, use_tool=True)
except:
logger.warning("Failed to parse tool call, maybe the response is not a tool call or have been answered.")
logger.warning(
"Failed to parse tool call, maybe the response is not a tool call or has already been answered.")

if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
@@ -239,11 +286,108 @@ async def predict(model_id: str, params: dict):
yield '[DONE]'


def predict_stream(model_id, gen_params):
"""
Make function calls compatible with stream-mode output.
Inspect roughly the first seven characters of the generated text: if they do not
indicate a function call, stream the output directly; otherwise, return the
complete text of the function call.
:param model_id:
:param gen_params:
:return:
"""
output = ""
is_function_call = False
has_send_first_chunk = False
for new_response in generate_stream_chatglm3(model, tokenizer, gen_params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):]
output = decoded_unicode

# When it is not yet known to be a function call and the output length is > 7,
# try to decide whether it is a function call based on the special function-name prefix
if not is_function_call and len(output) > 7:

# Determine whether a function is called
is_function_call = contains_custom_function(output)
if is_function_call:
continue

# Non-function call, direct stream output
finish_reason = new_response["finish_reason"]
send_msg = delta_text if has_send_first_chunk else output
has_send_first_chunk = True
message = DeltaMessage(
content=send_msg,
role="assistant",
function_call=None,
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=message,
finish_reason=finish_reason
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))

if is_function_call:
yield output
else:
yield '[DONE]'


async def parse_output_text(model_id: str, value: str):
"""
Directly output the text content of value
:param model_id:
:param value:
:return:
"""
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=value),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))

choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'


def contains_custom_function(value: str) -> bool:
"""
Determine whether the output is a 'function_call' based on a special function-name prefix.
For example, the functions defined in "tool_using/tool_register.py" are all named "get_xxx" and start with "get_".
[Note] This is not a rigorous check and is for reference only.
:param value:
:return:
"""
return value and 'get_' in value


if __name__ == "__main__":

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
if 'cuda' in DEVICE: # AMD, NVIDIA GPU can use Half Precision
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()

# Multi-GPU support: use the following two lines instead of the line above, and set num_gpus to your actual number of graphics cards
# from utils import load_model_on_gpus
# model = load_model_on_gpus(MODEL_PATH, num_gpus=2)

else: # CPU, Intel GPU and other GPU can use Float16 Precision Only
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
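
The contains_custom_function check added above is a plain substring heuristic. Below is a small sketch of what it does and does not match; the example strings are illustrative and assume the tools registered in tool_using/tool_register.py all start with "get_":

def contains_custom_function(value: str) -> bool:
    # Same heuristic as in the commit: any occurrence of "get_" counts as a tool call.
    return bool(value) and "get_" in value

# A genuine tool-call prefix emitted by the model is detected...
print(contains_custom_function('get_weather\n{"city_name": "Beijing"}'))  # True
# ...but so is ordinary prose that merely mentions a get_-prefixed name,
# which is why the docstring calls the check "not rigorous".
print(contains_custom_function("You can call get_weather to check."))     # True
print(contains_custom_function("It is sunny in Beijing today."))          # False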
2 changes: 1 addition & 1 deletion openai_api_demo/openai_api_request.py
@@ -109,5 +109,5 @@ def simple_chat(use_stream=True):


if __name__ == "__main__":
function_chat(use_stream=False)
function_chat(use_stream=True)
# simple_chat(use_stream=True)
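
With function_chat(use_stream=True), the client now consumes the tool-call result as an SSE stream. A minimal stand-alone client sketch follows, assuming the server from openai_api.py is reachable at http://localhost:8000/v1/chat/completions and that a get_weather tool is registered in tool_using/tool_register.py; the tool name and schema are illustrative, not part of this commit:

import json
import requests

payload = {
    "model": "chatglm3-6b",
    "stream": True,
    "messages": [{"role": "user", "content": "What is the weather in Beijing?"}],
    "functions": [{
        "name": "get_weather",  # hypothetical tool; must match a tool registered in tool_register.py
        "description": "Query the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city_name": {"type": "string"}},
            "required": ["city_name"],
        },
    }],
}

with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True, timeout=300) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # sse_starlette prefixes each event payload with "data: ".
        if not line or not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        delta = json.loads(data)["choices"][0]["delta"]
        # If a tool was dispatched server-side, these chunks already contain the
        # post-tool answer; otherwise they are the ordinary streamed completion.
        print(delta.get("content") or "", end="", flush=True)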
