Skip to content

Commit

Permalink
新增 千帆大模型的接入;对llm相关源码加载和逻辑处理做了大优化,可能有bug,待测试
Browse files Browse the repository at this point in the history
  • Loading branch information
Ikaros-521 committed Dec 15, 2023
1 parent afd9242 commit c83b00d
Show file tree
Hide file tree
Showing 11 changed files with 2,392 additions and 2,464 deletions.
1,960 changes: 985 additions & 975 deletions config.json

Large diffs are not rendered by default.

1,960 changes: 985 additions & 975 deletions config.json.bak

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,5 @@ pure-protobuf==3.0.0a5
pyaudio
flask
flask_cors
xingchen
xingchen
git+https://github.com/Ikaros-521/WenxinWorkshop-Python-SDK
3 changes: 2 additions & 1 deletion requirements_common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,5 @@ flask==2.3.3
flask_cors==4.0.0
xingchen==1.0.7
qianfan==0.2.2
socketio==0.2.1
socketio==0.2.1
git+https://github.com/Ikaros-521/WenxinWorkshop-Python-SDK
6 changes: 6 additions & 0 deletions src/Function/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
exports.handler = async event => {
// Log the event argument for debugging and for use in local development.
console.log(JSON.stringify(event, undefined, 2));

return {};
};
4 changes: 4 additions & 0 deletions src/Function/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"name": "function",
"version": "1.0.0"
}
131 changes: 131 additions & 0 deletions tests/test_wenxinworkshop/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import json, logging, traceback
from wenxinworkshop import LLMAPI, EmbeddingAPI, PromptTemplateAPI
from wenxinworkshop import Message, Messages, Texts

# 前往官网:https://cloud.baidu.com/product/wenxinworkshop 申请服务获取

class My_WenXinWorkShop:
    """Chat wrapper around Baidu WenxinWorkshop (Qianfan) LLM endpoints.

    Apply for API access at: https://cloud.baidu.com/product/wenxinworkshop
    """

    def __init__(self, data):
        """Create the LLM client from a configuration dict.

        Args:
            data (dict): expects keys "model", "api_key", "secret_key",
                "temperature", "top_p", "penalty_score", "history_enable",
                "history_max_len".
        """
        self.config_data = data
        # Conversation history: flat list of {"role": ..., "content": ...} dicts.
        self.history = []
        # Always define the attribute so a failed init degrades gracefully
        # instead of raising AttributeError later in get_resp().
        self.my_bot = None

        try:
            # Map the config "model" name to its LLMAPI endpoint URL.
            # BUGFIX: the original table listed "ERNIEBot_4_0" twice and had
            # "CHATLAW" / "QIANFAN_BLOOMZ_7B_COMPRESSED" cross-wired (the later
            # duplicate key silently overwrote the earlier one, so both models
            # resolved to the wrong endpoint).
            model_url_map = {
                "ERNIEBot": LLMAPI.ERNIEBot,
                "ERNIEBot_turbo": LLMAPI.ERNIEBot_turbo,
                "ERNIEBot_4_0": LLMAPI.ERNIEBot_4_0,
                "BLOOMZ_7B": LLMAPI.BLOOMZ_7B,
                "LLAMA_2_7B": LLMAPI.LLAMA_2_7B,
                "LLAMA_2_13B": LLMAPI.LLAMA_2_13B,
                "LLAMA_2_70B": LLMAPI.LLAMA_2_70B,
                "QIANFAN_BLOOMZ_7B_COMPRESSED": LLMAPI.QIANFAN_BLOOMZ_7B_COMPRESSED,
                "QIANFAN_CHINESE_LLAMA_2_7B": LLMAPI.QIANFAN_CHINESE_LLAMA_2_7B,
                "CHATGLM2_6B_32K": LLMAPI.CHATGLM2_6B_32K,
                "AQUILACHAT_7B": LLMAPI.AQUILACHAT_7B,
                "ERNIE_BOT_8K": LLMAPI.ERNIE_BOT_8K,
                "CODELLAMA_7B_INSTRUCT": LLMAPI.CODELLAMA_7B_INSTRUCT,
                "XUANYUAN_70B_CHAT": LLMAPI.XUANYUAN_70B_CHAT,
                "CHATLAW": LLMAPI.CHATLAW,
            }

            selected_model = self.config_data["model"]
            if selected_model in model_url_map:
                self.my_bot = LLMAPI(
                    api_key=self.config_data["api_key"],
                    secret_key=self.config_data["secret_key"],
                    url=model_url_map[selected_model]
                )
            else:
                # Previously an unknown model name silently left my_bot unset;
                # fail loudly at init time instead.
                logging.error(f"Unknown WenxinWorkshop model: {selected_model}")
        except Exception:
            logging.error(traceback.format_exc())

    def get_resp(self, prompt):
        """Send *prompt* (plus any stored history) to the LLM and return its reply.

        Args:
            prompt (str): the user question.

        Returns:
            str | None: the model's text answer, or None on any failure.
        """
        try:
            # Build the message list: prior turns first, then the new prompt.
            messages: Messages = []

            for history in self.history:
                messages.append(Message(
                    role=history["role"],
                    content=history["content"]
                ))

            messages.append(Message(
                role='user',
                content=prompt
            ))

            logging.info(f"self.history={self.history}")

            # Blocking (non-streaming) call to the LLM endpoint.
            resp_content = self.my_bot(
                messages=messages,
                temperature=self.config_data["temperature"],
                top_p=self.config_data["top_p"],
                penalty_score=self.config_data["penalty_score"],
                stream=None,
                user_id=None,
                chunk_size=512
            )

            # Remember this exchange when history is enabled.
            if self.config_data["history_enable"]:
                max_len = self.config_data["history_max_len"]
                # Trim the oldest user/assistant pairs until the total stored
                # character count fits the budget. Guarding on self.history
                # avoids the IndexError the original unconditional double
                # pop(0) could raise on a short/odd-length history.
                while self.history and sum(
                    len(item["content"]) for item in self.history if "content" in item
                ) > max_len:
                    del self.history[0:2]
                self.history.append({"role": "user", "content": prompt})
                self.history.append({"role": "assistant", "content": resp_content})

            return resp_content

        except Exception:
            # Full traceback (the original logged only the bare exception).
            logging.error(traceback.format_exc())
            return None

if __name__ == '__main__':
    # Configure root-logger output format for this smoke test.
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.DEBUG,  # adjust verbosity as needed
    )

    test_config = {
        "model": "ERNIEBot",
        "api_key": "",
        "secret_key": "",
        "top_p": 0.8,
        "temperature": 0.9,
        "penalty_score": 1.0,
        "history_enable": True,
        "history_max_len": 300
    }

    # Instantiate the wrapper and fire two sample prompts at it.
    bot = My_WenXinWorkShop(test_config)
    for question in ("你可以扮演猫娘吗,每句话后面加个喵", "早上好"):
        logging.info(bot.get_resp(question))
3 changes: 2 additions & 1 deletion utils/gpt_model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from utils.gpt_model.tongyi import TongYi
from utils.gpt_model.tongyixingchen import TongYiXingChen
from utils.gpt_model.my_qianfan import My_QianFan

from utils.gpt_model.my_wenxinworkshop import My_WenXinWorkShop

class GPT_Model:
openai = None
Expand All @@ -42,6 +42,7 @@ def set_model_config(self, model_name, config):
"yiyan": Yiyan,
"tongyi": TongYi,
"tongyixingchen": TongYiXingChen,
"my_wenxinworkshop": My_WenXinWorkShop,
"my_qianfan": My_QianFan
}

Expand Down
110 changes: 110 additions & 0 deletions utils/gpt_model/my_wenxinworkshop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import json, logging, traceback
from wenxinworkshop import LLMAPI, EmbeddingAPI, PromptTemplateAPI
from wenxinworkshop import Message, Messages, Texts

# 前往官网:https://cloud.baidu.com/product/wenxinworkshop 申请服务获取

class My_WenXinWorkShop:
    """Chat wrapper around Baidu WenxinWorkshop (Qianfan) LLM endpoints.

    Apply for API access at: https://cloud.baidu.com/product/wenxinworkshop
    """

    def __init__(self, data):
        """Create the LLM client from a configuration dict.

        Args:
            data (dict): expects keys "api_key", "secret_key", "temperature",
                "top_p", "penalty_score", "history_enable", "history_max_len",
                and optionally "model" (defaults to "ERNIEBot").
        """
        self.config_data = data
        # Conversation history: flat list of {"role": ..., "content": ...} dicts.
        self.history = []
        # Always define the attribute so a failed init degrades gracefully
        # instead of raising AttributeError later in get_resp().
        self.my_bot = None

        try:
            # Map the config "model" name to its LLMAPI endpoint URL.
            # BUGFIX: the original hardcoded url=LLMAPI.ERNIEBot and ignored
            # the "model" key carried by the config (which the demo honors).
            model_url_map = {
                "ERNIEBot": LLMAPI.ERNIEBot,
                "ERNIEBot_turbo": LLMAPI.ERNIEBot_turbo,
                "ERNIEBot_4_0": LLMAPI.ERNIEBot_4_0,
                "BLOOMZ_7B": LLMAPI.BLOOMZ_7B,
                "LLAMA_2_7B": LLMAPI.LLAMA_2_7B,
                "LLAMA_2_13B": LLMAPI.LLAMA_2_13B,
                "LLAMA_2_70B": LLMAPI.LLAMA_2_70B,
                "QIANFAN_BLOOMZ_7B_COMPRESSED": LLMAPI.QIANFAN_BLOOMZ_7B_COMPRESSED,
                "QIANFAN_CHINESE_LLAMA_2_7B": LLMAPI.QIANFAN_CHINESE_LLAMA_2_7B,
                "CHATGLM2_6B_32K": LLMAPI.CHATGLM2_6B_32K,
                "AQUILACHAT_7B": LLMAPI.AQUILACHAT_7B,
                "ERNIE_BOT_8K": LLMAPI.ERNIE_BOT_8K,
                "CODELLAMA_7B_INSTRUCT": LLMAPI.CODELLAMA_7B_INSTRUCT,
                "XUANYUAN_70B_CHAT": LLMAPI.XUANYUAN_70B_CHAT,
                "CHATLAW": LLMAPI.CHATLAW,
            }

            # Backward compatible: a missing or unknown "model" falls back to
            # ERNIEBot, which is what the original always used.
            selected_model = self.config_data.get("model", "ERNIEBot")
            url = model_url_map.get(selected_model)
            if url is None:
                logging.error(f"Unknown WenxinWorkshop model: {selected_model}")
                url = LLMAPI.ERNIEBot

            self.my_bot = LLMAPI(
                api_key=self.config_data["api_key"],
                secret_key=self.config_data["secret_key"],
                url=url
            )
        except Exception:
            logging.error(traceback.format_exc())

    def get_resp(self, prompt):
        """Send *prompt* (plus any stored history) to the LLM and return its reply.

        Args:
            prompt (str): the user question.

        Returns:
            str | None: the model's text answer, or None on any failure.
        """
        try:
            # Build the message list: prior turns first, then the new prompt.
            messages: Messages = []

            for history in self.history:
                messages.append(Message(
                    role=history["role"],
                    content=history["content"]
                ))

            messages.append(Message(
                role='user',
                content=prompt
            ))

            logging.info(f"self.history={self.history}")

            # Blocking (non-streaming) call to the LLM endpoint.
            resp_content = self.my_bot(
                messages=messages,
                temperature=self.config_data["temperature"],
                top_p=self.config_data["top_p"],
                penalty_score=self.config_data["penalty_score"],
                stream=None,
                user_id=None,
                chunk_size=512
            )

            # Remember this exchange when history is enabled.
            if self.config_data["history_enable"]:
                max_len = self.config_data["history_max_len"]
                # Trim the oldest user/assistant pairs until the total stored
                # character count fits the budget. Guarding on self.history
                # avoids the IndexError the original unconditional double
                # pop(0) could raise on a short/odd-length history.
                while self.history and sum(
                    len(item["content"]) for item in self.history if "content" in item
                ) > max_len:
                    del self.history[0:2]
                self.history.append({"role": "user", "content": prompt})
                self.history.append({"role": "assistant", "content": resp_content})

            return resp_content

        except Exception:
            # Full traceback (the original logged only the bare exception).
            logging.error(traceback.format_exc())
            return None

if __name__ == '__main__':
    # Configure root-logger output format for this smoke test.
    logging.basicConfig(
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.DEBUG,  # adjust verbosity as needed
    )

    test_config = {
        "model": "ERNIEBot",
        "api_key": "",
        "secret_key": "",
        "top_p": 0.8,
        "temperature": 0.9,
        "penalty_score": 1.0,
        "history_enable": True,
        "history_max_len": 300
    }

    # Instantiate the wrapper and fire two sample prompts at it.
    bot = My_WenXinWorkShop(test_config)
    for question in ("你可以扮演猫娘吗,每句话后面加个喵", "早上好"):
        logging.info(bot.get_resp(question))
Loading

0 comments on commit c83b00d

Please sign in to comment.