LLM: add gradio integration for LLM_TPU
Ikaros-521 committed Jun 16, 2024
1 parent c1abef2 commit ade0a73
Showing 7 changed files with 183 additions and 7 deletions.
9 changes: 9 additions & 0 deletions config.json
@@ -529,6 +529,14 @@
"mode": "chat",
"workspace_slug": "test"
},
"llm_tpu": {
"api_ip_port": "http://127.0.0.1:3001",
"max_length": 1,
"top_p": 0.8,
"temperature": 0.95,
"history_enable": true,
"history_max_len": 300
},
"custom_llm": {
"url": "http://127.0.0.1:11434/v1/chat/completions",
"headers": "Content-Type:application/json\nAuthorization:Bearer sk",
@@ -1770,6 +1778,7 @@
"koboldcpp": true,
"anythingllm": true,
"gpt4free": true,
"llm_tpu": true,
"custom_llm": true
},
"tts": {
9 changes: 9 additions & 0 deletions config.json.bak
@@ -529,6 +529,14 @@
"mode": "chat",
"workspace_slug": "test"
},
"llm_tpu": {
"api_ip_port": "http://127.0.0.1:3001",
"max_length": 1,
"top_p": 0.8,
"temperature": 0.95,
"history_enable": true,
"history_max_len": 300
},
"custom_llm": {
"url": "http://127.0.0.1:11434/v1/chat/completions",
"headers": "Content-Type:application/json\nAuthorization:Bearer sk",
@@ -1770,6 +1778,7 @@
"koboldcpp": true,
"anythingllm": true,
"gpt4free": true,
"llm_tpu": true,
"custom_llm": true
},
"tts": {
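Both config.json and config.json.bak gain the same llm_tpu block: api_ip_port points at the gradio web demo that LLM_TPU exposes, max_length/top_p/temperature are forwarded to the /predict endpoint, and history_enable/history_max_len control the character-capped context memory implemented in utils/gpt_model/llm_tpu.py below. As a minimal sketch (not part of the commit), the block can be loaded and sanity-checked like this, with the value ranges for temperature and top_p taken from the tooltips added in webui.py in this commit:

import json

# Load the new llm_tpu block and sanity-check it (a sketch; the project
# itself reads config.json through its own Config helper).
with open("config.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)["llm_tpu"]

assert cfg["api_ip_port"].startswith("http"), "expects the gradio demo URL"
assert 0.0 < cfg["temperature"] <= 1.0   # webui tooltip documents (0, 1.0]
assert 0.0 <= cfg["top_p"] <= 1.0        # webui tooltip documents [0, 1.0]
assert int(cfg["history_max_len"]) > 0   # character cap for context memory
print(cfg)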
36 changes: 36 additions & 0 deletions tests/test_llm_tpu/gradio_api.py
@@ -0,0 +1,36 @@
from gradio_client import Client
import re

client = Client("http://127.0.0.1:8003/")

try:
result = client.predict(
input="你可以扮演猫娘吗",
chatbot=[],
max_length=1,
top_p=0.8,
temperature=0.95,
api_name="/predict"
)
    # /predict returns the updated chatbot history; result[-1][1] is the newest reply
response_text = result[-1][1]
# Remove <p> and </p> tags using regex
cleaned_text = re.sub(r'</?p>', '', response_text)
print(cleaned_text)

result2 = client.predict(
input="你好",
chatbot=result,
max_length=1,
top_p=0.8,
temperature=0.95,
api_name="/predict"
)
# print(result2)
    # Again, result2[-1][1] is the newest reply in the returned history
response_text = result2[-1][1]
# Remove <p> and </p> tags using regex
cleaned_text = re.sub(r'</?p>', '', response_text)
print(cleaned_text)
except Exception as e:
print(f"An error occurred: {e}")
2 changes: 2 additions & 0 deletions utils/gpt_model/gpt.py
@@ -31,6 +31,7 @@
from utils.gpt_model.anythingllm import AnythingLLM
from utils.gpt_model.gpt4free import GPT4Free
from utils.gpt_model.custom_llm import Custom_LLM
from utils.gpt_model.llm_tpu import LLM_TPU

class GPT_Model:
openai = None
@@ -58,6 +59,7 @@ def set_model_config(self, model_name, config):
"anythingllm": AnythingLLM,
"gpt4free": GPT4Free,
"custom_llm": Custom_LLM,
"llm_tpu": LLM_TPU,
}

if model_name == "openai":
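set_model_config dispatches on the model-name string through a dict of wrapper classes, so wiring in the new backend is just the import plus one registry entry. A self-contained sketch of the pattern (LLM_TPU here is a stand-in class, not the real wrapper):

# Registry dispatch as used by GPT_Model.set_model_config: the model-name
# string from config.json selects the wrapper class to instantiate.
class LLM_TPU:  # stand-in for utils.gpt_model.llm_tpu.LLM_TPU
    def __init__(self, data):
        self.data = data

MODEL_CLASSES = {
    "llm_tpu": LLM_TPU,  # the entry this commit adds
}

def build_model(model_name, config):
    cls = MODEL_CLASSES.get(model_name)
    if cls is None:
        raise ValueError(f"unknown model: {model_name}")
    return cls(config)

model = build_model("llm_tpu", {"api_ip_port": "http://127.0.0.1:8003/"})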
86 changes: 86 additions & 0 deletions utils/gpt_model/llm_tpu.py
@@ -0,0 +1,86 @@
import logging, traceback
from gradio_client import Client
import re

from utils.common import Common
from utils.logger import Configure_logger


class LLM_TPU:
def __init__(self, data):
self.common = Common()
        # Log file path
file_path = "./log/log-" + self.common.get_bj_time(1) + ".txt"
Configure_logger(file_path)

self.config_data = data
self.history = []

self.history_enable = data["history_enable"]
self.history_max_len = data["history_max_len"]


def get_resp(self, data):
"""请求对应接口,获取返回值
Args:
data (dict): 你的提问等
Returns:
str: 返回的文本回答
"""
try:
client = Client(self.config_data["api_ip_port"])

result = client.predict(
input=data["prompt"],
chatbot=self.history,
max_length=self.config_data["max_length"],
top_p=self.config_data["top_p"],
temperature=self.config_data["temperature"],
api_name="/predict"
)

response_text = result[-1][1]
# Remove <p> and </p> tags using regex
resp_content = re.sub(r'</?p>', '', response_text)

            # /predict returns the full updated chatbot history; keep it only
            # when context memory is enabled
            if self.history_enable:
                self.history = result
                while True:
                    # Total character count across all strings in the nested history list
                    total_chars = sum(len(string) for sublist in self.history for string in sublist)
                    # If it exceeds the configured cap, drop the oldest question/answer pair
                    if total_chars > self.history_max_len:
                        self.history.pop(0)
                    else:
                        break

return resp_content
except Exception as e:
logging.error(traceback.format_exc())
return None

if __name__ == '__main__':
    # Configure the logging output format
logging.basicConfig(
        level=logging.DEBUG,  # log level; adjust as needed
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

data = {
"api_ip_port": "http://127.0.0.1:8003/",
"max_length": 1,
"top_p": 0.8,
"temperature": 0.95,
"history_enable": True,
"history_max_len": 300
}

llm_tpu = LLM_TPU(data)
logging.info(f'{llm_tpu.get_resp("你可以扮演猫娘吗,每句话后面加个喵")}')
logging.info(f'{llm_tpu.get_resp("早上好")}')

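Note that history_max_len caps characters, not turns: get_resp sums the length of every string in the nested history and drops the oldest [user, reply] pair until the total fits. A standalone sketch of that trimming loop with a worked example:

# Character-based history trimming, as in LLM_TPU.get_resp: drop the
# oldest question/answer pair until the total character count fits.
def trim_history(history, max_chars):
    while sum(len(s) for pair in history for s in pair) > max_chars:
        history.pop(0)
    return history

history = [["你好", "你好呀,喵~"], ["早上好", "早上好,喵~"]]
print(trim_history(history, max_chars=10))  # keeps only the newest pair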
5 changes: 4 additions & 1 deletion utils/my_handle.py
@@ -171,12 +171,14 @@ def __init__(self, config_path):
self.anythingllm = None
self.gpt4free = None
self.custom_llm = None
self.llm_tpu = None

self.image_recognition_model = None

self.chat_type_list = ["chatgpt", "claude", "claude2", "chatglm", "qwen", "chat_with_file", "text_generation_webui", \
"sparkdesk", "langchain_chatglm", "langchain_chatchat", "zhipu", "bard", "yiyan", "tongyi", \
"tongyixingchen", "my_qianfan", "my_wenxinworkshop", "gemini", "qanything", "koboldcpp", "anythingllm", "gpt4free", "custom_llm"]
"tongyixingchen", "my_qianfan", "my_wenxinworkshop", "gemini", "qanything", "koboldcpp", "anythingllm", "gpt4free", \
"custom_llm", "llm_tpu"]

# 配置加载
self.config_load()
@@ -1533,6 +1535,7 @@ def llm_handle(self, chat_type, data, type="chat", webui_show=True):
"anythingllm": lambda: self.anythingllm.get_resp({"prompt": data["content"]}),
"gpt4free": lambda: self.gpt4free.get_resp({"prompt": data["content"]}),
"custom_llm": lambda: self.custom_llm.get_resp({"prompt": data["content"]}),
"llm_tpu": lambda: self.llm_tpu.get_resp({"prompt": data["content"]}),
"reread": lambda: data["content"]
}
elif type == "vision":
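llm_handle resolves chat_type through a dict of zero-argument lambdas, so only the selected backend is actually invoked; the new entry routes "llm_tpu" to self.llm_tpu.get_resp. A minimal sketch of the dispatch (the reply string is a stand-in, not the real backend call):

# Chat-type dispatch as in llm_handle: wrapping each backend in a lambda
# defers the call until after the lookup, so only one backend runs.
def llm_handle(chat_type, data):
    handlers = {
        "llm_tpu": lambda: f"(llm_tpu reply to: {data['content']})",  # stand-in
        "reread": lambda: data["content"],
    }
    handler = handlers.get(chat_type)
    return handler() if handler else None

print(llm_handle("reread", {"content": "你好"}))   # echoes the input
print(llm_handle("llm_tpu", {"content": "你好"}))  # stand-in reply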
43 changes: 37 additions & 6 deletions webui.py
@@ -2072,6 +2072,14 @@ def common_textarea_handle(content):
config_data["custom_llm"]["data_analysis"] = textarea_custom_llm_data_analysis.value
config_data["custom_llm"]["resp_template"] = textarea_custom_llm_resp_template.value

if config.get("webui", "show_card", "llm", "llm_tpu"):
config_data["llm_tpu"]["api_ip_port"] = input_llm_tpu_api_ip_port.value
config_data["llm_tpu"]["history_enable"] = switch_llm_tpu_history_enable.value
config_data["llm_tpu"]["history_max_len"] = int(input_llm_tpu_history_max_len.value)
config_data["llm_tpu"]["max_length"] = round(float(input_llm_tpu_max_length.value), 2)
config_data["llm_tpu"]["temperature"] = round(float(input_llm_tpu_temperature.value), 2)
config_data["llm_tpu"]["top_p"] = round(float(input_llm_tpu_top_p.value), 2)

"""
TTS
"""
@@ -2619,6 +2627,7 @@ def common_textarea_handle(content):
config_data["webui"]["show_card"]["llm"]["anythingllm"] = switch_webui_show_card_llm_anythingllm.value
config_data["webui"]["show_card"]["llm"]["gpt4free"] = switch_webui_show_card_llm_gpt4free.value
config_data["webui"]["show_card"]["llm"]["custom_llm"] = switch_webui_show_card_llm_custom_llm.value
config_data["webui"]["show_card"]["llm"]["llm_tpu"] = switch_webui_show_card_llm_llm_tpu.value

config_data["webui"]["show_card"]["tts"]["edge-tts"] = switch_webui_show_card_tts_edge_tts.value
config_data["webui"]["show_card"]["tts"]["vits"] = switch_webui_show_card_tts_vits.value
@@ -2773,6 +2782,7 @@ def save_config():
'anythingllm': 'AnythingLLM',
'tongyi': '通义千问',
'gpt4free': 'GPT4Free',
'llm_tpu': 'LLM_TPU',
'custom_llm': '自定义LLM',
}

@@ -3033,16 +3043,16 @@ def save_config():
textarea_filter_after_must_str_for_llm = ui.textarea(label='LLM触发后缀', placeholder='后缀必须携带其中任一字符串才能触发LLM\n例如:配置。那么这个会触发:你好。', value=textarea_data_change(config.get("filter", "before_must_str_for_llm"))).style("width:200px;").tooltip('后缀必须携带其中任一字符串才能触发LLM\n例如:配置。那么这个会触发:你好。')

with ui.row():
- input_filter_max_len = ui.input(label='最大单词数', placeholder='最长阅读的英文单词数(空格分隔)', value=config.get("filter", "max_len")).style("width:150px;")
- input_filter_max_char_len = ui.input(label='最大单词数', placeholder='最长阅读的字符数,双重过滤,避免溢出', value=config.get("filter", "max_char_len")).style("width:150px;")
- switch_filter_username_convert_digits_to_chinese = ui.switch('用户名中的数字转中文', value=config.get("filter", "username_convert_digits_to_chinese")).style(switch_internal_css)
+ input_filter_max_len = ui.input(label='最大单词数', placeholder='最长阅读的英文单词数(空格分隔)', value=config.get("filter", "max_len")).style("width:150px;").tooltip('最长阅读的英文单词数(空格分隔)')
+ input_filter_max_char_len = ui.input(label='最大单词数', placeholder='最长阅读的字符数,双重过滤,避免溢出', value=config.get("filter", "max_char_len")).style("width:150px;").tooltip('最长阅读的字符数,双重过滤,避免溢出')
+ switch_filter_username_convert_digits_to_chinese = ui.switch('用户名中的数字转中文', value=config.get("filter", "username_convert_digits_to_chinese")).style(switch_internal_css).tooltip('用户名中的数字转中文')
switch_filter_emoji = ui.switch('弹幕表情过滤', value=config.get("filter", "emoji")).style(switch_internal_css)
with ui.grid(columns=5):
switch_filter_badwords_enable = ui.switch('违禁词过滤', value=config.get("filter", "badwords", "enable")).style(switch_internal_css)
switch_filter_badwords_discard = ui.switch('违禁语句丢弃', value=config.get("filter", "badwords", "discard")).style(switch_internal_css)
- input_filter_badwords_path = ui.input(label='违禁词路径', value=config.get("filter", "badwords", "path"), placeholder='本地违禁词数据路径(你如果不需要,可以清空文件内容)').style("width:200px;")
- input_filter_badwords_bad_pinyin_path = ui.input(label='违禁拼音路径', value=config.get("filter", "badwords", "bad_pinyin_path"), placeholder='本地违禁拼音数据路径(你如果不需要,可以清空文件内容)').style("width:200px;")
- input_filter_badwords_replace = ui.input(label='违禁词替换', value=config.get("filter", "badwords", "replace"), placeholder='在不丢弃违禁语句的前提下,将违禁词替换成此项的文本').style("width:200px;")
+ input_filter_badwords_path = ui.input(label='违禁词路径', value=config.get("filter", "badwords", "path"), placeholder='本地违禁词数据路径(你如果不需要,可以清空文件内容)').style("width:200px;").tooltip('本地违禁词数据路径(你如果不需要,可以清空文件内容)')
+ input_filter_badwords_bad_pinyin_path = ui.input(label='违禁拼音路径', value=config.get("filter", "badwords", "bad_pinyin_path"), placeholder='本地违禁拼音数据路径(你如果不需要,可以清空文件内容)').style("width:200px;").tooltip('本地违禁拼音数据路径(你如果不需要,可以清空文件内容)')
+ input_filter_badwords_replace = ui.input(label='违禁词替换', value=config.get("filter", "badwords", "replace"), placeholder='在不丢弃违禁语句的前提下,将违禁词替换成此项的文本').style("width:200px;").tooltip('在不丢弃违禁语句的前提下,将违禁词替换成此项的文本')

with ui.expansion('消息遗忘&保留设置', icon="settings", value=True).classes('w-full'):
with ui.element('div').classes('p-2 bg-blue-100'):
@@ -4327,6 +4337,26 @@ def anythingllm_get_workspaces_list():
textarea_custom_llm_data_analysis = ui.textarea(label=f"数据解析(eval执行)", value=config.get("custom_llm", "data_analysis"), placeholder='数据解析,请不要随意修改resp变量,会被用于最后返回数据内容的解析').style("width:300px;").tooltip('数据解析,请不要随意修改resp变量,会被用于最后返回数据内容的解析')
textarea_custom_llm_resp_template = ui.textarea(label=f"返回内容模板", value=config.get("custom_llm", "resp_template"), placeholder='请不要随意删除data变量,支持动态变量,最终会合并成完成内容进行音频合成').style("width:300px;").tooltip('请不要随意删除data变量,支持动态变量,最终会合并成完成内容进行音频合成')

if config.get("webui", "show_card", "llm", "llm_tpu"):
with ui.card().style(card_css):
ui.label("LLM_TPU")
with ui.row():
input_llm_tpu_api_ip_port = ui.input(
label='API地址',
value=config.get("llm_tpu", "api_ip_port"),
placeholder='llm_tpu启动gradio web demo后监听的ip端口地址',
validation={
'请输入正确格式的URL': lambda value: common.is_url_check(value),
}
)
switch_llm_tpu_history_enable = ui.switch('上下文记忆', value=config.get("llm_tpu", "history_enable")).style(switch_internal_css)
input_llm_tpu_history_max_len = ui.input(label='最大记忆长度', value=config.get("llm_tpu", "history_max_len"), placeholder='最长能记忆的问答字符串长度,超长会丢弃最早记忆的内容,请慎用!配置过大可能会有丢大米')

with ui.row():
input_llm_tpu_max_length = ui.input(label='max_length', value=config.get("llm_tpu", "max_length"), placeholder='max_length').style("width:200px;")
input_llm_tpu_temperature = ui.input(label='温度', value=config.get("llm_tpu", "temperature"), placeholder='(0, 1.0] 控制生成文本的随机性。较高的温度值会使生成的文本更随机和多样化,而较低的温度值会使生成的文本更加确定和一致。').style("width:200px;")
input_llm_tpu_top_p = ui.input(label='前p个选择', value=config.get("llm_tpu", "top_p"), placeholder='[0, 1.0] Nucleus采样。这个参数控制模型从累积概率大于一定阈值的令牌中进行采样。较高的值会产生更多的多样性,较低的值会产生更少但更确定的回答。').style("width:200px;")

with ui.tab_panel(tts_page).style(tab_panel_css):
# 通用-合成试听音频
async def tts_common_audio_synthesis():
@@ -5991,6 +6021,7 @@ def update_echart_gift():
switch_webui_show_card_llm_anythingllm = ui.switch('AnythingLLM', value=config.get("webui", "show_card", "llm", "anythingllm")).style(switch_internal_css)
switch_webui_show_card_llm_gpt4free = ui.switch('GPT4Free', value=config.get("webui", "show_card", "llm", "gpt4free")).style(switch_internal_css)
switch_webui_show_card_llm_custom_llm = ui.switch('自定义LLM', value=config.get("webui", "show_card", "llm", "custom_llm")).style(switch_internal_css)
switch_webui_show_card_llm_llm_tpu = ui.switch('LLM_TPU', value=config.get("webui", "show_card", "llm", "llm_tpu")).style(switch_internal_css)

with ui.card().style(card_css):
ui.label("文本转语音")
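On the save path, nicegui inputs hold strings, so common_textarea_handle casts the llm_tpu fields back to numbers before writing them into config_data, mirroring the other backends. A tiny sketch of that round-trip under the same assumption:

# The numeric llm_tpu fields come back from the UI as strings and are cast
# on save: int for the history cap, round(float(...), 2) for sampling knobs.
raw = {"history_max_len": "300", "temperature": "0.95", "top_p": "0.8"}
saved = {
    "history_max_len": int(raw["history_max_len"]),
    "temperature": round(float(raw["temperature"]), 2),
    "top_p": round(float(raw["top_p"]), 2),
}
print(saved)  # {'history_max_len': 300, 'temperature': 0.95, 'top_p': 0.8}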
