forked from Ikaros-521/AI-Vtuber
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3a77e79
commit d4b83c0
Showing
10 changed files
with
319 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
"ruff.args": [ | ||
"--ignore=F401,E402,F541" | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import json, logging | ||
import re, requests | ||
import traceback | ||
from urllib.parse import urljoin | ||
import sys | ||
sys.path.insert(1, "../../utils") | ||
#from utils.common import Common | ||
from loguru import logger | ||
|
||
|
||
class Dify: | ||
def __init__(self, data): | ||
#self.common = Common() | ||
self.config_data = data | ||
|
||
self.conversation_id = "" | ||
|
||
logger.debug(self.config_data) | ||
|
||
|
||
def get_resp(self, data): | ||
"""请求对应接口,获取返回值 | ||
Args: | ||
data (dict): 含有提问的json数据 | ||
Returns: | ||
str: 返回的文本回答 | ||
""" | ||
try: | ||
resp_content = None | ||
|
||
if self.config_data["type"] == "聊天助手": | ||
API_URL = urljoin(self.config_data["api_ip_port"], '/v1/chat-messages') | ||
|
||
data_json = { | ||
"inputs": {}, | ||
"query": data["prompt"], | ||
# 阻塞模式 | ||
"response_mode": "blocking", | ||
# 会话 ID,需要基于之前的聊天记录继续对话,必须传之前消息的 conversation_id。 | ||
"conversation_id": self.conversation_id, | ||
# 用户名是否区分 视情况而定,暂时为了稳定性统一 | ||
"user": "test" | ||
} | ||
headers = { | ||
'Content-Type': 'application/json', | ||
'Authorization': f'Bearer {self.config_data["api_key"]}' | ||
} | ||
|
||
response = requests.request("POST", API_URL, headers=headers, json=data_json) | ||
resp_json = json.loads(response.content) | ||
|
||
logger.debug(f"resp_json={resp_json}") | ||
|
||
if "answer" in resp_json: | ||
resp_content = resp_json["answer"] | ||
|
||
# 是否记录历史 | ||
if self.config_data["history_enable"]: | ||
self.conversation_id = resp_json["conversation_id"] | ||
else: | ||
logger.error(f"获取LLM返回失败。{resp_json}") | ||
return None | ||
|
||
return resp_content | ||
|
||
except Exception as e: | ||
logger.error(traceback.format_exc()) | ||
|
||
return None | ||
|
||
if __name__ == '__main__': | ||
|
||
|
||
data = { | ||
"api_ip_port": "http://172.26.189.21/v1", | ||
"type": "聊天助手", | ||
"api_key": "app-64xu0vQjP2kxN4DKR8Ch7ZGY", | ||
"history_enable": True | ||
} | ||
|
||
# 实例化并调用 | ||
dify = Dify(data) | ||
logger.info(dify.get_resp({"prompt": "你可以扮演猫娘吗,每句话后面加个喵"})) | ||
logger.info(dify.get_resp({"prompt": "早上好"})) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from typing import Iterator | ||
import time | ||
import requests | ||
import numpy as np | ||
import resampy | ||
import pyaudio | ||
|
||
class VitsTTS: | ||
audio = None | ||
stream = None | ||
sample_rate = 16000 # Set the desired sample rate for playback | ||
|
||
def __init__(self, config_json): | ||
self.config_json = config_json | ||
|
||
# Initialize PyAudio and stream if not already done | ||
if VitsTTS.audio is None: | ||
VitsTTS.audio = pyaudio.PyAudio() | ||
if VitsTTS.stream is None or not VitsTTS.stream.is_active(): | ||
VitsTTS.stream = VitsTTS.audio.open( | ||
format=pyaudio.paFloat32, | ||
channels=1, | ||
rate=VitsTTS.sample_rate, | ||
output=True | ||
) | ||
|
||
def txt_to_audio(self, msg): | ||
self.stream_tts( | ||
self.gpt_sovits( | ||
msg, | ||
self.config_json["ref_file"], | ||
self.config_json["ref_text"], | ||
"zh", # Language (can be "en" or other supported languages) | ||
self.config_json["server_url"] | ||
) | ||
) | ||
|
||
def gpt_sovits(self, text, reffile, reftext, language, server_url) -> Iterator[bytes]: | ||
start = time.perf_counter() | ||
req = { | ||
'text': text, | ||
'text_lang': language, | ||
'ref_audio_path': reffile, | ||
'prompt_text': reftext, | ||
'prompt_lang': language, | ||
'media_type': 'raw', | ||
'streaming_mode': True, | ||
} | ||
|
||
res = requests.post( | ||
f"{server_url}/tts", | ||
json=req, | ||
stream=True, | ||
) | ||
end = time.perf_counter() | ||
print(f"gpt_sovits Time to make POST: {end-start}s") | ||
|
||
if res.status_code != 200: | ||
print("Error:", res.text) | ||
return | ||
|
||
first = True | ||
for chunk in res.iter_content(chunk_size=32000): # 32K*20ms*2 | ||
if first: | ||
end = time.perf_counter() | ||
print(f"gpt_sovits Time to first chunk: {end-start}s") | ||
first = False | ||
if chunk: | ||
yield chunk | ||
|
||
print("gpt_sovits response.elapsed:", res.elapsed) | ||
|
||
def stream_tts(self, audio_stream): | ||
for chunk in audio_stream: | ||
if chunk is not None and len(chunk) > 0: | ||
stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 | ||
stream = resampy.resample(x=stream, sr_orig=32000, sr_new=VitsTTS.sample_rate) | ||
VitsTTS.stream.write(stream.tobytes()) | ||
|
||
@classmethod | ||
def close_audio(cls): | ||
if cls.stream is not None: | ||
cls.stream.stop_stream() | ||
cls.stream.close() | ||
cls.audio.terminate() | ||
cls.stream = None | ||
cls.audio = None | ||
|
||
|
||
if __name__ == "__main__": | ||
config_json = { | ||
"server_url": "http://127.0.0.1:9880", | ||
"ref_file": "E:\\GitHub_pro\\AI-Vtuber\\out\\edge_tts_3.mp3", | ||
"ref_text": "就送个人气票,看不起谁呢", | ||
} | ||
vits_tts = VitsTTS(config_json) | ||
vits_tts.txt_to_audio("你好,我是AI") | ||
vits_tts.txt_to_audio("我的声音如何") | ||
vits_tts.txt_to_audio("床前明月光,疑是地上霜。举头望明月,低头思故乡。") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import json | ||
import requests | ||
import traceback | ||
from urllib.parse import urljoin | ||
from loguru import logger | ||
|
||
|
||
class Dify: | ||
def __init__(self, data: dict): | ||
self.config_data = data | ||
|
||
self.conversation_id = "" | ||
|
||
# logger.debug(self.config_data) | ||
|
||
|
||
def get_resp(self, data: dict): | ||
"""请求对应接口,获取返回值 | ||
Args: | ||
data (dict): 含有提问的json数据 | ||
Returns: | ||
str: 返回的文本回答 | ||
""" | ||
try: | ||
resp_content = None | ||
|
||
if self.config_data["type"] == "聊天助手": | ||
API_URL = urljoin(self.config_data["api_ip_port"], '/v1/chat-messages') | ||
|
||
data_json = { | ||
"inputs": {}, | ||
"query": data["prompt"], | ||
# 阻塞模式 | ||
"response_mode": "blocking", | ||
# 会话 ID,需要基于之前的聊天记录继续对话,必须传之前消息的 conversation_id。 | ||
"conversation_id": self.conversation_id, | ||
# 用户名是否区分 视情况而定,暂时为了稳定性统一 | ||
"user": "test" | ||
} | ||
headers = { | ||
'Content-Type': 'application/json', | ||
'Authorization': f'Bearer {self.config_data["api_key"]}' | ||
} | ||
|
||
response = requests.request("POST", API_URL, headers=headers, json=data_json) | ||
resp_json = json.loads(response.content) | ||
|
||
logger.debug(f"resp_json={resp_json}") | ||
|
||
if "answer" in resp_json: | ||
resp_content = resp_json["answer"] | ||
|
||
# 是否记录历史 | ||
if self.config_data["history_enable"]: | ||
self.conversation_id = resp_json["conversation_id"] | ||
else: | ||
logger.error(f"获取LLM返回失败。{resp_json}") | ||
return None | ||
|
||
return resp_content | ||
|
||
except Exception as e: | ||
logger.error(traceback.format_exc()) | ||
|
||
return None | ||
|
||
if __name__ == '__main__': | ||
data = { | ||
"api_ip_port": "http://172.26.189.21/v1", | ||
"type": "聊天助手", | ||
"api_key": "app-64xu0vQjP2kxN4DKR8Ch7ZGY", | ||
"history_enable": True | ||
} | ||
|
||
# 实例化并调用 | ||
dify = Dify(data) | ||
logger.info(dify.get_resp({"prompt": "你可以扮演猫娘吗,每句话后面加个喵"})) | ||
logger.info(dify.get_resp({"prompt": "早上好"})) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.