Skip to content

Commit

Permalink
feat: add xunfei api
Browse files Browse the repository at this point in the history
  • Loading branch information
tianminghui committed Oct 26, 2023
1 parent 3ff1b39 commit f55cc6a
Show file tree
Hide file tree
Showing 8 changed files with 371 additions and 10 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

model-config.json
model-config.json

.test.py
8 changes: 8 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
"program": "open-api.py",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: 当前文件",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true
}
]
}
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none"
}
6 changes: 4 additions & 2 deletions adapters/adapter_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


import json
from loguru import logger
from adapters.azure import AzureAdapter
Expand All @@ -8,6 +6,7 @@
from adapters.claude_web import ClaudeWebModel
from adapters.proxy import ProxyAdapter
from adapters.zhipu_api import ZhiPuApiModel
from adapters.xunfei_spark import XunfeiSparkAPIModel

model_instance_dict = {}

Expand All @@ -30,6 +29,9 @@ def get_adapter(instanceKey: str, type: str, **kwargs) -> ModelAdapter:
elif type == "zhipu-api":
model = ZhiPuApiModel(**kwargs)

elif type == "xunfei-spark-api":
model = XunfeiSparkAPIModel(**kwargs)

else:
raise ValueError(f"unknown model type: {type}")

Expand Down
6 changes: 4 additions & 2 deletions adapters/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class ChatCompletionRequest(BaseModel):
model: Optional[str] = "gpt-3.5-turbo"
messages: List[ChatMessage]
functions: Optional[List[Dict]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = None # between 0 and 2 Defaults to 1
top_p: Optional[float] = None # Defaults to 1
max_length: Optional[int] = None
stream: Optional[bool] = False
stop: Optional[List[str]] = None
Expand All @@ -51,11 +51,13 @@ class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]


class Usage(BaseModel):
    # Token-usage accounting mirroring the OpenAI API "usage" object;
    # all counters default to 0 for backends that do not report usage.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ChatCompletionResponse(BaseModel):
id: str = f"chatcmpl-{str(time.time())}"
model: str
Expand Down
148 changes: 148 additions & 0 deletions adapters/xunfei_spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import json
from typing import Iterator
from adapters.base import ModelAdapter
from adapters.protocol import ChatCompletionRequest, ChatCompletionResponse
from loguru import logger
from clients.xunfei_spark.api.spark_api import SparkAPI
import time
import uuid


class XunfeiSparkAPIModel(ModelAdapter):
    """Adapter that exposes the iFlytek (Xunfei) Spark API through the
    OpenAI chat-completion protocol.

    Required kwargs: ``app_id``, ``api_key``, ``api_secret``,
    ``api_model_version``. Optional ``prompt`` is a template used to fold
    a ``system`` message into a user turn (Spark has no system role here).
    Any remaining kwargs are forwarded to every Spark request.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.app_id = kwargs.pop("app_id")
        self.api_key = kwargs.pop("api_key")
        self.api_secret = kwargs.pop("api_secret")
        self.api_model_version = kwargs.pop("api_model_version")
        # Template used to rewrite a "system" message as a user turn.
        self.prompt = kwargs.pop(
            "prompt", "You need to follow the system settings:{system}"
        )
        # Whatever is left after the pops is passed through on every call
        # (and deliberately overrides per-request values — see update() below).
        self.config_args = kwargs
        self.api_connection = SparkAPI(
            self.app_id, self.api_key, self.api_secret, self.api_model_version
        )

    def chat_completions(
        self, request: ChatCompletionRequest
    ) -> Iterator[ChatCompletionResponse]:
        """Run one chat completion against Spark.

        Yields one chunk per Spark frame when ``request.stream`` is true,
        otherwise yields a single aggregated response.

        Raises:
            Exception: when a Spark frame reports a non-zero error code.
        """
        messages = self.openai_to_client_params(request)
        call_kwargs = {
            # uuid4 instead of uuid1: no MAC/timestamp leakage, and Spark
            # only needs an opaque unique id; stringified for JSON safety.
            "chat_id": str(uuid.uuid4()),
        }
        # Use `is not None` so an explicit 0 / 0.0 is not silently dropped
        # (bare truthiness would discard temperature=0.0).
        if request.temperature is not None:
            # OpenAI temperature ranges over 0-2; Spark expects 0-1.
            call_kwargs["temperature"] = request.temperature / 2
        if request.max_length is not None:
            call_kwargs["max_tokens"] = request.max_length

        call_kwargs.update(self.config_args)
        iter_content = self.api_connection.get_resp_from_messages(
            messages, **call_kwargs
        )

        if request.stream:
            for line in iter_content:
                code = line["header"]["code"]
                if code != 0:
                    logger.error(f"请求失败:{line}")
                    raise Exception(f"请求失败:{line}")
                openai_response = self.client_response_2_chatgpt_response_stream(line)
                yield ChatCompletionResponse(**openai_response)
        else:
            openai_response = self.client_response_to_chatgpt_response(iter_content)
            yield ChatCompletionResponse(**openai_response)

    def openai_to_client_params(self, openai_params: ChatCompletionRequest):
        """Convert OpenAI-style messages into Spark's prompt list.

        ``system`` messages are rewritten as a user turn (via self.prompt)
        followed by a stub assistant "ok" acknowledgement, because Spark
        has no system role. ``function`` messages are unsupported.
        """
        prompt = []
        for message in openai_params.messages:
            role = message.role
            if role in ["function"]:
                raise Exception(f"不支持的功能:{role}")
            if role == "system":  # Fold system content into a user message.
                role = "user"
                content = self.prompt.format(system=message.content)
                prompt.append({"role": role, "content": content})
                prompt.append({"role": "assistant", "content": "ok"})
            else:
                content = message.content
                prompt.append({"role": role, "content": content})
        return prompt

    def client_response_2_chatgpt_response_stream(self, resp_json):
        """Map one Spark streaming frame to an OpenAI `chat.completion.chunk` dict.

        Usage counters are only present on the final frame (status == 2);
        earlier frames report zeros.
        """
        completion = resp_json["payload"]["choices"]["text"][0]["content"]
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0
        # status == 2 marks the final frame, which carries usage figures.
        is_final = resp_json["payload"]["choices"]["status"] == 2
        if is_final:
            usage = resp_json["payload"]["usage"]["text"]
            prompt_tokens = usage["prompt_tokens"]
            completion_tokens = usage["completion_tokens"]
            total_tokens = usage["total_tokens"]

        return {
            "id": resp_json["header"]["sid"],
            "object": "chat.completion.chunk",
            # Model name is faked so OpenAI clients accept the response.
            "created": int(time.time()),
            "model": "gpt-3.5-turbo-0613",
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
            },
            "choices": [
                {
                    "delta": {
                        "role": "assistant",
                        "content": completion,
                    },
                    "index": 0,
                    "finish_reason": "stop" if is_final else None,
                }
            ],
        }

    def client_response_to_chatgpt_response(self, iter_resp):
        """Drain the Spark frame iterator and build one `chat.completion` dict.

        Concatenates all frame contents; takes id and usage from the final
        frame (status == 2).

        Raises:
            Exception: when any frame reports a non-zero error code.
        """
        completions = []
        resp_id = None  # renamed from `id` to avoid shadowing the builtin
        prompt_tokens = 0
        completion_tokens = 0
        total_tokens = 0

        for resp_json in iter_resp:
            code = resp_json["header"]["code"]
            if code != 0:
                logger.error(f"请求失败:{resp_json}")
                raise Exception(f"请求失败:{resp_json}")
            completions.append(resp_json["payload"]["choices"]["text"][0]["content"])
            resp_id = resp_json["header"]["sid"]
            logger.info(f"resp_json: {resp_json}")
            if resp_json["payload"]["choices"]["status"] == 2:
                usage = resp_json["payload"]["usage"]["text"]
                prompt_tokens = usage["prompt_tokens"]
                completion_tokens = usage["completion_tokens"]
                total_tokens = usage["total_tokens"]
        # NOTE(review): the OpenAI `chat.completion` object keys the choice
        # under "message", not "delta" — kept as "delta" here because the
        # project's ChatCompletionResponse schema (not visible) may expect it;
        # verify against adapters/protocol.py.
        return {
            "id": resp_id,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "gpt-3.5-turbo-0613",
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": total_tokens,
            },
            "choices": [
                {
                    "delta": {
                        "role": "assistant",
                        "content": "".join(completions),
                    },
                    "index": 0,
                    "finish_reason": "stop",
                }
            ],
        }
Loading

0 comments on commit f55cc6a

Please sign in to comment.