Skip to content

Commit

Permalink
Merge pull request Ikaros-521#556 from Ikaros-521/owner
Browse files Browse the repository at this point in the history
text-gen新增官方api的对接并新增传参top_p、top_k、温度、seed
  • Loading branch information
Ikaros-521 authored Jan 6, 2024
2 parents 8cc4319 + 55775cb commit 30e665f
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 74 deletions.
4 changes: 4 additions & 0 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@
"character": "Example",
"instruction_template": "Vicuna-v1.1",
"your_name": "主人",
"top_p": 1.0,
"top_k": 40,
"temperature": 0.7,
"seed": -1.0,
"history_enable": true,
"history_max_len": 500
},
Expand Down
4 changes: 4 additions & 0 deletions config.json.bak
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@
"character": "Example",
"instruction_template": "Vicuna-v1.1",
"your_name": "主人",
"top_p": 1.0,
"top_k": 40,
"temperature": 0.7,
"seed": -1.0,
"history_enable": true,
"history_max_len": 500
},
Expand Down
237 changes: 172 additions & 65 deletions utils/gpt_model/text_generation_webui.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json, logging
import requests
import json, logging, traceback
import requests, re
from urllib.parse import urljoin

from utils.common import Common
from utils.logger import Configure_logger
Expand All @@ -25,7 +26,10 @@ def __init__(self, data):
self.history_enable = data["history_enable"]
self.history_max_len = data["history_max_len"]

self.history = {"internal": [], "visible": []}
if self.config_data["type"] == "coyude":
self.history = {"internal": [], "visible": []}
else:
self.history = []

# 合并函数
def merge_jsons(self, json_list):
Expand All @@ -48,82 +52,185 @@ def remove_first_group(self, json_obj):
return json_obj

def get_resp(self, user_input):
request = {
'user_input': user_input,
'max_new_tokens': self.max_new_tokens,
'history': self.history,
'mode': self.mode, # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': self.character, # 'TavernAI-Gawr Gura'
'instruction_template': self.instruction_template,
'your_name': self.your_name,

'regenerate': False,
'_continue': False,
'stop_at_newline': False,
'chat_generation_attempts': 1,
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
if self.config_data["type"] == "coyude":
request = {
'user_input': user_input,
'max_new_tokens': self.max_new_tokens,
'history': self.history,
'mode': self.mode, # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': self.character, # 'TavernAI-Gawr Gura'
'instruction_template': self.instruction_template,
'your_name': self.your_name,

# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'regenerate': False,
'_continue': False,
'stop_at_newline': False,
'chat_generation_attempts': 1,
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',

'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'skip_special_tokens': True,
'stopping_strings': []
}
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'top_k': 40,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,

try:
response = requests.post(self.api_ip_port + "/api/v1/chat", json=request)
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'skip_special_tokens': True,
'stopping_strings': []
}

if response.status_code == 200:
result = response.json()['results'][0]['history']
# logging.info(json.dumps(result, indent=4))
# print(result['visible'][-1][1])
resp_content = result['visible'][-1][1]
try:
url = urljoin(self.api_ip_port, "/api/v1/chat")
response = requests.post(url=url, json=request)

if response.status_code == 200:
result = response.json()['results'][0]['history']
# logging.info(json.dumps(result, indent=4))
# print(result['visible'][-1][1])
resp_content = result['visible'][-1][1]

# 启用历史就给我记住!
if self.history_enable:
while True:
# 统计字符数
total_chars = sum(len(item) for sublist in self.history['internal'] for item in sublist)
total_chars += sum(len(item) for sublist in self.history['visible'] for item in sublist)
logging.info(f"total_chars={total_chars}")
# 如果大于限定最大历史数,就剔除第一个元素
if total_chars > self.history_max_len:
self.history = self.remove_first_group(self.history)
else:
self.history = result
break

return resp_content
else:
return None
except Exception as e:
logging.error(traceback.format_exc())
return None
else:
try:
url = urljoin(self.api_ip_port, "/v1/chat/completions")

headers = {
"Content-Type": "application/json"
}

self.history.append({"role": "user", "content": user_input})

data = {
"messages": self.history,
"temperature": self.config_data["temperature"],
"top_p": self.config_data["top_p"],
"top_k": self.config_data["top_k"],
# "user": "string",
"mode": self.mode,
"character": self.character,
"max_tokens": self.max_new_tokens,
"instruction_template": self.instruction_template,
"stream": False,
"seed": self.config_data["seed"],
# "preset": self.config_data["preset"],
# "stop": "string",
# "n": 1,
# "presence_penalty": 0,
# "model": "string",
# "instruction_template_str": "string",
# "frequency_penalty": 0,
# "function_call": "string",
# "functions": [
# {}
# ],
# "logit_bias": {},
# "name1": "string",
# "name2": "string",
# "context": "string",
# "greeting": "string",
# "chat_template_str": "string",
# "chat_instruct_command": "string",
# "continue_": False,
# "min_p": 0,
# "repetition_penalty": 1,
# "repetition_penalty_range": 1024,
# "typical_p": 1,
# "tfs": 1,
# "top_a": 0,
# "epsilon_cutoff": 0,
# "eta_cutoff": 0,
# "guidance_scale": 1,
# "negative_prompt": "",
# "penalty_alpha": 0,
# "mirostat_mode": 0,
# "mirostat_tau": 5,
# "mirostat_eta": 0.1,
# "temperature_last": False,
# "do_sample": True,
# "encoder_repetition_penalty": 1,
# "no_repeat_ngram_size": 0,
# "min_length": 0,
# "num_beams": 1,
# "length_penalty": 1,
# "early_stopping": False,
# "truncation_length": 0,
# "max_tokens_second": 0,
# "custom_token_bans": "",
# "auto_max_new_tokens": False,
# "ban_eos_token": False,
# "add_bos_token": True,
# "skip_special_tokens": True,
# "grammar_string": ""
}

logging.debug(data)

response = requests.post(url, headers=headers, json=data, verify=False)
resp_json = response.json()
logging.debug(resp_json)

resp_content = resp_json['choices'][0]['message']['content']
# 过滤多余的 \n
resp_content = re.sub(r'\n+', '\n', resp_content)
# 从字符串的两端或者一端删除指定的字符,默认是空格或者换行符
resp_content = resp_content.rstrip('\n')
self.history.append({"role": "assistant", "content": resp_content})

# 启用历史就给我记住!
if self.history_enable:
while True:
# 统计字符数
total_chars = sum(len(item) for sublist in self.history['internal'] for item in sublist)
total_chars += sum(len(item) for sublist in self.history['visible'] for item in sublist)
logging.info(f"total_chars={total_chars}")
# 如果大于限定最大历史数,就剔除第一个元素
total_chars = sum(len(i['content']) for i in self.history)
# 如果大于限定最大历史数,就剔除第1 个元素
if total_chars > self.history_max_len:
self.history = self.remove_first_group(self.history)
self.history.pop(0)
# self.history.pop(0)
else:
self.history = result
break

return resp_content
else:
except Exception as e:
logging.error(traceback.format_exc())
return None
except Exception as e:
logging.info(e)
return None


# 源于官方 api-example.py
Expand Down
4 changes: 2 additions & 2 deletions utils/gpt_model/tongyi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import revTongYi
import json, logging
import json, logging, traceback

from utils.common import Common
from utils.logger import Configure_logger
Expand Down Expand Up @@ -57,7 +57,7 @@ def get_resp(self, prompt):

return ret["content"][0]
except Exception as e:
logging.error(e)
logging.error(traceback.format_exc())
return None


Expand Down
2 changes: 2 additions & 0 deletions utils/my_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,8 @@ def llm_handle(self, chat_type, data):
# 使用字典映射的方式来获取响应内容
resp_content = chat_model_methods.get(chat_type, lambda: data["content"])()

logging.debug(f"resp_content={resp_content}")

if resp_content is None:
My_handle.abnormal_alarm_data["llm"]["error_count"] += 1
self.abnormal_alarm_handle("llm")
Expand Down
27 changes: 20 additions & 7 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,14 @@ def common_textarea_handle(content):
config_data["text_generation_webui"]["max_new_tokens"] = int(input_text_generation_webui_max_new_tokens.value)
config_data["text_generation_webui"]["history_enable"] = switch_text_generation_webui_history_enable.value
config_data["text_generation_webui"]["history_max_len"] = int(input_text_generation_webui_history_max_len.value)
config_data["text_generation_webui"]["mode"] = input_text_generation_webui_mode.value
config_data["text_generation_webui"]["mode"] = select_text_generation_webui_mode.value
config_data["text_generation_webui"]["character"] = input_text_generation_webui_character.value
config_data["text_generation_webui"]["instruction_template"] = input_text_generation_webui_instruction_template.value
config_data["text_generation_webui"]["your_name"] = input_text_generation_webui_your_name.value
config_data["text_generation_webui"]["top_p"] = round(float(input_text_generation_webui_top_p.value), 2)
config_data["text_generation_webui"]["top_k"] = int(input_text_generation_webui_top_k.value)
config_data["text_generation_webui"]["temperature"] = round(float(input_text_generation_webui_temperature.value), 2)
config_data["text_generation_webui"]["seed"] = float(input_text_generation_webui_seed.value)

config_data["sparkdesk"]["type"] = select_sparkdesk_type.value
config_data["sparkdesk"]["cookie"] = input_sparkdesk_cookie.value
Expand Down Expand Up @@ -1827,7 +1831,7 @@ def common_textarea_handle(content):
with ui.row():
select_text_generation_webui_type = ui.select(
label='类型',
options={"coyude": "coyude"},
options={"官方API": "官方API", "coyude": "coyude"},
value=config.get("text_generation_webui", "type")
)
input_text_generation_webui_api_ip_port = ui.input(label='API地址', placeholder='text-generation-webui开启API模式后监听的IP和端口地址', value=config.get("text_generation_webui", "api_ip_port"))
Expand All @@ -1838,14 +1842,23 @@ def common_textarea_handle(content):
input_text_generation_webui_history_max_len = ui.input(label='最大记忆长度', placeholder='最大记忆的上下文字符数量,不建议设置过大,容易爆显存,自行根据情况配置', value=config.get("text_generation_webui", "history_max_len"))
input_text_generation_webui_history_max_len.style("width:200px")
with ui.row():
input_text_generation_webui_mode = ui.input(label='模式', placeholder='自行查阅', value=config.get("text_generation_webui", "mode"))
input_text_generation_webui_mode.style("width:300px")
select_text_generation_webui_mode = ui.select(
label='类型',
options={"chat": "chat", "chat-instruct": "chat-instruct", "instruct": "instruct"},
value=config.get("text_generation_webui", "mode")
).style("width:150px")
input_text_generation_webui_character = ui.input(label='character', placeholder='自行查阅', value=config.get("text_generation_webui", "character"))
input_text_generation_webui_character.style("width:300px")
input_text_generation_webui_character.style("width:100px")
input_text_generation_webui_instruction_template = ui.input(label='instruction_template', placeholder='自行查阅', value=config.get("text_generation_webui", "instruction_template"))
input_text_generation_webui_instruction_template.style("width:300px")
input_text_generation_webui_instruction_template.style("width:150px")
input_text_generation_webui_your_name = ui.input(label='your_name', placeholder='自行查阅', value=config.get("text_generation_webui", "your_name"))
input_text_generation_webui_your_name.style("width:300px")
input_text_generation_webui_your_name.style("width:100px")
with ui.row():
input_text_generation_webui_top_p = ui.input(label='top_p', value=config.get("text_generation_webui", "top_p"), placeholder='topP生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。默认值 0.95。注意,取值不要大于等于1')
input_text_generation_webui_top_k = ui.input(label='top_k', value=config.get("text_generation_webui", "top_k"), placeholder='匹配搜索结果条数')
input_text_generation_webui_temperature = ui.input(label='temperature', value=config.get("text_generation_webui", "temperature"), placeholder='较高的值将使输出更加随机,而较低的值将使输出更加集中和确定。可选,默认取值0.92')
input_text_generation_webui_seed = ui.input(label='seed', value=config.get("text_generation_webui", "seed"), placeholder='seed生成时,随机数的种子,用于控制模型生成的随机性。如果使用相同的种子,每次运行生成的结果都将相同;当需要复现模型的生成结果时,可以使用相同的种子。seed参数支持无符号64位整数类型。默认值 1683806810')

with ui.card().style(card_css):
ui.label("讯飞星火")
with ui.grid(columns=2):
Expand Down

0 comments on commit 30e665f

Please sign in to comment.