forked from Ikaros-521/AI-Vtuber
-
Notifications
You must be signed in to change notification settings - Fork 0
/
yiyan.py
157 lines (120 loc) · 5.51 KB
/
yiyan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import json
import requests, time
from requests.exceptions import ConnectionError, RequestException
from utils.common import Common
from utils.my_log import logger
# Originally planned to integrate with: https://github.com/zhuweiyou/yiyan-api
class Yiyan:
    """Chat client for Baidu's Yiyan service with optional conversation history.

    The back end is selected by ``data["type"]``:

    * ``"web"``: the prompt and the configured cookie are posted to
      ``data["web"]["api_ip_port"] + "/headless"``.
    * any other value: the official chat-completions endpoint is called with
      an access token obtained from ``data["api"]``.
    """

    def __init__(self, data):
        """Store the configuration and start with an empty history.

        Args:
            data (dict): Configuration. This class reads the keys ``type``,
                ``web``, ``api``, ``history_enable`` and ``history_max_len``.
        """
        self.common = Common()
        self.config_data = data
        self.type = data["type"]
        # Alternating {"role": "user"} / {"role": "assistant"} messages,
        # always appended in pairs by _update_history().
        self.history = []

    def get_access_token(self):
        """Exchange the configured API key / secret key for an access token.

        Returns:
            str | None: The ``access_token`` field of the JSON response, or
            ``None`` when the response does not contain that field.
        """
        url = "https://aip.baidubce.com/oauth/2.0/token"
        # Passing the credentials through ``params`` lets requests URL-encode
        # them instead of splicing raw values into the query string.
        params = {
            "grant_type": "client_credentials",
            "client_id": self.config_data["api"]["api_key"],
            "client_secret": self.config_data["api"]["secret_key"],
        }
        payload = json.dumps("")
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        response = requests.request("POST", url, params=params, headers=headers, data=payload)
        return response.json().get("access_token")

    def _update_history(self, prompt, resp_content):
        """Record one user/assistant exchange, trimming old exchanges first.

        Nothing is stored when ``history_enable`` is falsy. Otherwise the
        oldest exchanges are dropped while the characters of the stored
        message contents exceed ``history_max_len``; the new exchange is then
        appended, so the most recent exchange is always kept.

        Args:
            prompt (str): The user's question.
            resp_content (str): The answer returned by the service.
        """
        if not self.config_data["history_enable"]:
            return

        max_len = self.config_data["history_max_len"]
        # Count the message texts. Iterating a message dict directly would
        # yield its keys ("role", "content"), not its content.
        total_chars = sum(len(message["content"]) for message in self.history)
        while self.history and total_chars > max_len:
            # Drop a whole user/assistant pair so the history keeps starting
            # with a "user" message and stays strictly alternating.
            dropped = self.history[:2]
            del self.history[:2]
            total_chars -= sum(len(message["content"]) for message in dropped)

        self.history.append({"role": "user", "content": prompt})
        self.history.append({"role": "assistant", "content": resp_content})

    def _get_resp_web(self, prompt):
        """Ask the headless web service; return the answer or None on failure."""
        try:
            data_json = {
                "cookie": self.config_data["web"]["cookie"],
                "prompt": prompt
            }
            url = self.config_data["web"]["api_ip_port"] + "/headless"
            response = requests.post(url=url, data=data_json)
            response.raise_for_status()  # surface HTTP error statuses as exceptions
            ret = json.loads(response.content)
            logger.debug(ret)
            # Strip both real newlines and literal "\n" sequences from the text.
            resp_content = ret['text'].replace('\n', '').replace('\\n', '')
            self._update_history(prompt, resp_content)
            return resp_content
        except ConnectionError as ce:
            # Connection problems: service not running or address mismatch.
            logger.error(f"请检查你是否启动了服务端或配置是否匹配,连接异常:{ce}")
        except RequestException as re:
            # Any other request-level failure.
            logger.error(f"请求异常:{re}")
        except Exception as e:
            logger.error(e)
        return None

    def _get_resp_api(self, prompt):
        """Ask the official chat-completions API; return the answer or None."""
        access_token = self.get_access_token()
        if not access_token:
            # Without this guard the failure only showed up as an opaque
            # "can only concatenate str" TypeError in the log.
            logger.error("获取access_token失败,请检查api_key和secret_key配置")
            return None

        url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
        data_json = {
            "messages": self.history + [{"role": "user", "content": prompt}]
        }
        payload = json.dumps(data_json)
        headers = {
            'Content-Type': 'application/json'
        }
        response = requests.request(
            "POST", url, params={"access_token": access_token}, headers=headers, data=payload
        )
        logger.info(payload)
        logger.info(response.text)

        ret = json.loads(response.text)
        resp_content = ret.get("result")
        if resp_content is None:
            # No "result" field: report what the service sent back instead of
            # a bare KeyError message.
            logger.error(f"接口未返回result:error_code={ret.get('error_code')}, error_msg={ret.get('error_msg')}")
            return None

        self._update_history(prompt, resp_content)
        return resp_content

    def get_resp(self, prompt):
        """Send the prompt to the configured back end and return the answer.

        Args:
            prompt (str): The user's question.

        Returns:
            str | None: The text answer, or ``None`` when the request failed
            (the failure is logged, not raised).
        """
        try:
            if self.type == "web":
                return self._get_resp_web(prompt)
            return self._get_resp_api(prompt)
        except Exception as e:
            logger.error(e)
        return None
if __name__ == '__main__':
    # Manual smoke test: ask two questions in a row against the API back end.

    # Configure the log output format; adjust the level as needed.
    # NOTE(review): `logger` comes from utils.my_log, which is not visible
    # here - confirm it really exposes basicConfig() and DEBUG.
    log_options = dict(
        level=logger.DEBUG,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger.basicConfig(**log_options)

    # Settings for the local headless web service (unused while type is 'api').
    web_settings = {"api_ip_port": "http://127.0.0.1:3000", "cookie": ''}
    # Credentials for the official API; fill these in before running.
    api_settings = {"api_key": "", "secret_key": ""}

    data = {
        "type": 'api',
        "web": web_settings,
        "api": api_settings,
        "history_enable": True,
        "history_max_len": 300,
    }

    yiyan = Yiyan(data)

    first_answer = yiyan.get_resp("你可以扮演猫娘吗,每句话后面加个喵")
    logger.info(first_answer)

    # Short pause between the two requests.
    time.sleep(1)

    second_answer = yiyan.get_resp("早上好")
    logger.info(second_answer)