diff --git a/config.json b/config.json index 98153330..9384ac73 100644 --- a/config.json +++ b/config.json @@ -529,6 +529,14 @@ "mode": "chat", "workspace_slug": "test" }, + "llm_tpu": { + "api_ip_port": "http://127.0.0.1:3001", + "max_length": 1, + "top_p": 0.8, + "temperature": 0.95, + "history_enable": true, + "history_max_len": 300 + }, "custom_llm": { "url": "http://127.0.0.1:11434/v1/chat/completions", "headers": "Content-Type:application/json\nAuthorization:Bearer sk", @@ -1770,6 +1778,7 @@ "koboldcpp": true, "anythingllm": true, "gpt4free": true, + "llm_tpu": true, "custom_llm": true }, "tts": { diff --git a/config.json.bak b/config.json.bak index 98153330..9384ac73 100644 --- a/config.json.bak +++ b/config.json.bak @@ -529,6 +529,14 @@ "mode": "chat", "workspace_slug": "test" }, + "llm_tpu": { + "api_ip_port": "http://127.0.0.1:3001", + "max_length": 1, + "top_p": 0.8, + "temperature": 0.95, + "history_enable": true, + "history_max_len": 300 + }, "custom_llm": { "url": "http://127.0.0.1:11434/v1/chat/completions", "headers": "Content-Type:application/json\nAuthorization:Bearer sk", @@ -1770,6 +1778,7 @@ "koboldcpp": true, "anythingllm": true, "gpt4free": true, + "llm_tpu": true, "custom_llm": true }, "tts": { diff --git a/tests/test_llm_tpu/gradio_api.py b/tests/test_llm_tpu/gradio_api.py new file mode 100644 index 00000000..2bc143b7 --- /dev/null +++ b/tests/test_llm_tpu/gradio_api.py @@ -0,0 +1,36 @@ +from gradio_client import Client +import re + +client = Client("http://127.0.0.1:8003/") + +try: + result = client.predict( + input="你可以扮演猫娘吗", + chatbot=[], + max_length=1, + top_p=0.8, + temperature=0.95, + api_name="/predict" + ) + # Assuming result[0][1] contains the response text + response_text = result[-1][1] + # Remove
and
tags using regex + cleaned_text = re.sub(r'?p>', '', response_text) + print(cleaned_text) + + result2 = client.predict( + input="你好", + chatbot=result, + max_length=1, + top_p=0.8, + temperature=0.95, + api_name="/predict" + ) + # print(result2) + # Assuming result[0][1] contains the response text + response_text = result2[-1][1] + # Removeand
tags using regex + cleaned_text = re.sub(r'?p>', '', response_text) + print(cleaned_text) +except Exception as e: + print(f"An error occurred: {e}") diff --git a/utils/gpt_model/gpt.py b/utils/gpt_model/gpt.py index 428cc42c..d6c5b8e1 100644 --- a/utils/gpt_model/gpt.py +++ b/utils/gpt_model/gpt.py @@ -31,6 +31,7 @@ from utils.gpt_model.anythingllm import AnythingLLM from utils.gpt_model.gpt4free import GPT4Free from utils.gpt_model.custom_llm import Custom_LLM +from utils.gpt_model.llm_tpu import LLM_TPU class GPT_Model: openai = None @@ -58,6 +59,7 @@ def set_model_config(self, model_name, config): "anythingllm": AnythingLLM, "gpt4free": GPT4Free, "custom_llm": Custom_LLM, + "llm_tpu": LLM_TPU, } if model_name == "openai": diff --git a/utils/gpt_model/llm_tpu.py b/utils/gpt_model/llm_tpu.py new file mode 100644 index 00000000..965905df --- /dev/null +++ b/utils/gpt_model/llm_tpu.py @@ -0,0 +1,86 @@ +import logging, traceback +from gradio_client import Client +import re + +from utils.common import Common +from utils.logger import Configure_logger + + +class LLM_TPU: + def __init__(self, data): + self.common = Common() + # 日志文件路径 + file_path = "./log/log-" + self.common.get_bj_time(1) + ".txt" + Configure_logger(file_path) + + self.config_data = data + self.history = [] + + self.history_enable = data["history_enable"] + self.history_max_len = data["history_max_len"] + + + def get_resp(self, data): + """请求对应接口,获取返回值 + + Args: + data (dict): 你的提问等 + + Returns: + str: 返回的文本回答 + """ + try: + client = Client(self.config_data["api_ip_port"]) + + result = client.predict( + input=data["prompt"], + chatbot=self.history, + max_length=self.config_data["max_length"], + top_p=self.config_data["top_p"], + temperature=self.config_data["temperature"], + api_name="/predict" + ) + + response_text = result[-1][1] + # Removeand
tags using regex + resp_content = re.sub(r'?p>', '', response_text) + + self.history = result + + # 启用历史就给我记住! + if self.history_enable: + while True: + # 获取嵌套列表中所有字符串的字符数 + total_chars = sum(len(string) for sublist in self.history for string in sublist) + # 如果大于限定最大历史数,就剔除第一个元素 + if total_chars > self.history_max_len: + self.history.pop(0) + else: + break + + return resp_content + except Exception as e: + logging.error(traceback.format_exc()) + return None + +if __name__ == '__main__': + # 配置日志输出格式 + logging.basicConfig( + level=logging.DEBUG, # 设置日志级别,可以根据需求调整 + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + data = { + "api_ip_port": "http://127.0.0.1:8003/", + "max_length": 1, + "top_p": 0.8, + "temperature": 0.95, + "history_enable": True, + "history_max_len": 300 + } + + llm_tpu = LLM_TPU(data) + logging.info(f'{llm_tpu.get_resp("你可以扮演猫娘吗,每句话后面加个喵")}') + logging.info(f'{llm_tpu.get_resp("早上好")}') + diff --git a/utils/my_handle.py b/utils/my_handle.py index 5f6a038d..4d1854b4 100644 --- a/utils/my_handle.py +++ b/utils/my_handle.py @@ -171,12 +171,14 @@ def __init__(self, config_path): self.anythingllm = None self.gpt4free = None self.custom_llm = None + self.llm_tpu = None self.image_recognition_model = None self.chat_type_list = ["chatgpt", "claude", "claude2", "chatglm", "qwen", "chat_with_file", "text_generation_webui", \ "sparkdesk", "langchain_chatglm", "langchain_chatchat", "zhipu", "bard", "yiyan", "tongyi", \ - "tongyixingchen", "my_qianfan", "my_wenxinworkshop", "gemini", "qanything", "koboldcpp", "anythingllm", "gpt4free", "custom_llm"] + "tongyixingchen", "my_qianfan", "my_wenxinworkshop", "gemini", "qanything", "koboldcpp", "anythingllm", "gpt4free", \ + "custom_llm", "llm_tpu"] # 配置加载 self.config_load() @@ -1533,6 +1535,7 @@ def llm_handle(self, chat_type, data, type="chat", webui_show=True): "anythingllm": lambda: self.anythingllm.get_resp({"prompt": data["content"]}), "gpt4free": lambda: self.gpt4free.get_resp({"prompt": data["content"]}), "custom_llm": lambda: self.custom_llm.get_resp({"prompt": data["content"]}), + "llm_tpu": lambda: self.llm_tpu.get_resp({"prompt": data["content"]}), "reread": lambda: data["content"] } elif type == "vision": diff --git a/webui.py b/webui.py index e5fd0faf..0833d045 100644 --- a/webui.py +++ b/webui.py @@ -2072,6 +2072,14 @@ def common_textarea_handle(content): config_data["custom_llm"]["data_analysis"] = textarea_custom_llm_data_analysis.value config_data["custom_llm"]["resp_template"] = textarea_custom_llm_resp_template.value + if config.get("webui", "show_card", "llm", "llm_tpu"): + config_data["llm_tpu"]["api_ip_port"] = input_llm_tpu_api_ip_port.value + config_data["llm_tpu"]["history_enable"] = switch_llm_tpu_history_enable.value + config_data["llm_tpu"]["history_max_len"] = int(input_llm_tpu_history_max_len.value) + config_data["llm_tpu"]["max_length"] = round(float(input_llm_tpu_max_length.value), 2) + config_data["llm_tpu"]["temperature"] = round(float(input_llm_tpu_temperature.value), 2) + config_data["llm_tpu"]["top_p"] = round(float(input_llm_tpu_top_p.value), 2) + """ TTS """ @@ -2619,6 +2627,7 @@ def common_textarea_handle(content): config_data["webui"]["show_card"]["llm"]["anythingllm"] = switch_webui_show_card_llm_anythingllm.value config_data["webui"]["show_card"]["llm"]["gpt4free"] = switch_webui_show_card_llm_gpt4free.value config_data["webui"]["show_card"]["llm"]["custom_llm"] = switch_webui_show_card_llm_custom_llm.value + config_data["webui"]["show_card"]["llm"]["llm_tpu"] = switch_webui_show_card_llm_llm_tpu.value config_data["webui"]["show_card"]["tts"]["edge-tts"] = switch_webui_show_card_tts_edge_tts.value config_data["webui"]["show_card"]["tts"]["vits"] = switch_webui_show_card_tts_vits.value @@ -2773,6 +2782,7 @@ def save_config(): 'anythingllm': 'AnythingLLM', 'tongyi': '通义千问', 'gpt4free': 'GPT4Free', + 'llm_tpu': 'LLM_TPU', 'custom_llm': '自定义LLM', } @@ -3033,16 +3043,16 @@ def save_config(): textarea_filter_after_must_str_for_llm = ui.textarea(label='LLM触发后缀', placeholder='后缀必须携带其中任一字符串才能触发LLM\n例如:配置。那么这个会触发:你好。', value=textarea_data_change(config.get("filter", "before_must_str_for_llm"))).style("width:200px;").tooltip('后缀必须携带其中任一字符串才能触发LLM\n例如:配置。那么这个会触发:你好。') with ui.row(): - input_filter_max_len = ui.input(label='最大单词数', placeholder='最长阅读的英文单词数(空格分隔)', value=config.get("filter", "max_len")).style("width:150px;") - input_filter_max_char_len = ui.input(label='最大单词数', placeholder='最长阅读的字符数,双重过滤,避免溢出', value=config.get("filter", "max_char_len")).style("width:150px;") - switch_filter_username_convert_digits_to_chinese = ui.switch('用户名中的数字转中文', value=config.get("filter", "username_convert_digits_to_chinese")).style(switch_internal_css) + input_filter_max_len = ui.input(label='最大单词数', placeholder='最长阅读的英文单词数(空格分隔)', value=config.get("filter", "max_len")).style("width:150px;").tooltip('最长阅读的英文单词数(空格分隔)') + input_filter_max_char_len = ui.input(label='最大单词数', placeholder='最长阅读的字符数,双重过滤,避免溢出', value=config.get("filter", "max_char_len")).style("width:150px;").tooltip('最长阅读的字符数,双重过滤,避免溢出') + switch_filter_username_convert_digits_to_chinese = ui.switch('用户名中的数字转中文', value=config.get("filter", "username_convert_digits_to_chinese")).style(switch_internal_css).tooltip('用户名中的数字转中文') switch_filter_emoji = ui.switch('弹幕表情过滤', value=config.get("filter", "emoji")).style(switch_internal_css) with ui.grid(columns=5): switch_filter_badwords_enable = ui.switch('违禁词过滤', value=config.get("filter", "badwords", "enable")).style(switch_internal_css) switch_filter_badwords_discard = ui.switch('违禁语句丢弃', value=config.get("filter", "badwords", "discard")).style(switch_internal_css) - input_filter_badwords_path = ui.input(label='违禁词路径', value=config.get("filter", "badwords", "path"), placeholder='本地违禁词数据路径(你如果不需要,可以清空文件内容)').style("width:200px;") - input_filter_badwords_bad_pinyin_path = ui.input(label='违禁拼音路径', value=config.get("filter", "badwords", "bad_pinyin_path"), placeholder='本地违禁拼音数据路径(你如果不需要,可以清空文件内容)').style("width:200px;") - input_filter_badwords_replace = ui.input(label='违禁词替换', value=config.get("filter", "badwords", "replace"), placeholder='在不丢弃违禁语句的前提下,将违禁词替换成此项的文本').style("width:200px;") + input_filter_badwords_path = ui.input(label='违禁词路径', value=config.get("filter", "badwords", "path"), placeholder='本地违禁词数据路径(你如果不需要,可以清空文件内容)').style("width:200px;").tooltip('本地违禁词数据路径(你如果不需要,可以清空文件内容)') + input_filter_badwords_bad_pinyin_path = ui.input(label='违禁拼音路径', value=config.get("filter", "badwords", "bad_pinyin_path"), placeholder='本地违禁拼音数据路径(你如果不需要,可以清空文件内容)').style("width:200px;").tooltip('本地违禁拼音数据路径(你如果不需要,可以清空文件内容)') + input_filter_badwords_replace = ui.input(label='违禁词替换', value=config.get("filter", "badwords", "replace"), placeholder='在不丢弃违禁语句的前提下,将违禁词替换成此项的文本').style("width:200px;").tooltip('在不丢弃违禁语句的前提下,将违禁词替换成此项的文本') with ui.expansion('消息遗忘&保留设置', icon="settings", value=True).classes('w-full'): with ui.element('div').classes('p-2 bg-blue-100'): @@ -4327,6 +4337,26 @@ def anythingllm_get_workspaces_list(): textarea_custom_llm_data_analysis = ui.textarea(label=f"数据解析(eval执行)", value=config.get("custom_llm", "data_analysis"), placeholder='数据解析,请不要随意修改resp变量,会被用于最后返回数据内容的解析').style("width:300px;").tooltip('数据解析,请不要随意修改resp变量,会被用于最后返回数据内容的解析') textarea_custom_llm_resp_template = ui.textarea(label=f"返回内容模板", value=config.get("custom_llm", "resp_template"), placeholder='请不要随意删除data变量,支持动态变量,最终会合并成完成内容进行音频合成').style("width:300px;").tooltip('请不要随意删除data变量,支持动态变量,最终会合并成完成内容进行音频合成') + if config.get("webui", "show_card", "llm", "llm_tpu"): + with ui.card().style(card_css): + ui.label("LLM_TPU") + with ui.row(): + input_llm_tpu_api_ip_port = ui.input( + label='API地址', + value=config.get("llm_tpu", "api_ip_port"), + placeholder='llm_tpu启动gradio web demo后监听的ip端口地址', + validation={ + '请输入正确格式的URL': lambda value: common.is_url_check(value), + } + ) + switch_llm_tpu_history_enable = ui.switch('上下文记忆', value=config.get("llm_tpu", "history_enable")).style(switch_internal_css) + input_llm_tpu_history_max_len = ui.input(label='最大记忆长度', value=config.get("llm_tpu", "history_max_len"), placeholder='最长能记忆的问答字符串长度,超长会丢弃最早记忆的内容,请慎用!配置过大可能会有丢大米') + + with ui.row(): + input_llm_tpu_max_length = ui.input(label='max_length', value=config.get("llm_tpu", "max_length"), placeholder='max_length').style("width:200px;") + input_llm_tpu_temperature = ui.input(label='温度', value=config.get("llm_tpu", "temperature"), placeholder='(0, 1.0] 控制生成文本的随机性。较高的温度值会使生成的文本更随机和多样化,而较低的温度值会使生成的文本更加确定和一致。').style("width:200px;") + input_llm_tpu_top_p = ui.input(label='前p个选择', value=config.get("llm_tpu", "top_p"), placeholder='[0, 1.0] Nucleus采样。这个参数控制模型从累积概率大于一定阈值的令牌中进行采样。较高的值会产生更多的多样性,较低的值会产生更少但更确定的回答。').style("width:200px;") + with ui.tab_panel(tts_page).style(tab_panel_css): # 通用-合成试听音频 async def tts_common_audio_synthesis(): @@ -5991,6 +6021,7 @@ def update_echart_gift(): switch_webui_show_card_llm_anythingllm = ui.switch('AnythingLLM', value=config.get("webui", "show_card", "llm", "anythingllm")).style(switch_internal_css) switch_webui_show_card_llm_gpt4free = ui.switch('GPT4Free', value=config.get("webui", "show_card", "llm", "gpt4free")).style(switch_internal_css) switch_webui_show_card_llm_custom_llm = ui.switch('自定义LLM', value=config.get("webui", "show_card", "llm", "custom_llm")).style(switch_internal_css) + switch_webui_show_card_llm_llm_tpu = ui.switch('LLM_TPU', value=config.get("webui", "show_card", "llm", "llm_tpu")).style(switch_internal_css) with ui.card().style(card_css): ui.label("文本转语音")