forked from 2noise/ChatTTS
Commit 8a0e837 (0 parents)
Showing 12 changed files with 1,327 additions and 0 deletions.
@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.ckpt
# C extensions
*.so
*.pt

# Distribution / packaging
.Python
outputs/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
asset/*
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
@@ -0,0 +1 @@
from .core import Chat
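This single line is evidently the package-level init module, re-exporting Chat so callers can reach it from the package root. A minimal import sketch, assuming the package directory is named ChatTTS (as in the upstream repository):

import ChatTTS          # assumes the package folder is named ChatTTS

chat = ChatTTS.Chat()   # Chat is re-exported by the one-line module above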
@@ -0,0 +1,118 @@
import os
import logging
from omegaconf import OmegaConf

import torch
from vocos import Vocos
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .infer.api import refine_text, infer_code

from huggingface_hub import snapshot_download

logging.basicConfig(level = logging.INFO)


class Chat:
    def __init__(self, ):
        self.pretrain_models = {}
        self.logger = logging.getLogger(__name__)

    def check_model(self, level = logging.INFO, use_decoder = False):
        not_finish = False
        check_list = ['vocos', 'gpt', 'tokenizer']

        if use_decoder:
            check_list.append('decoder')
        else:
            check_list.append('dvae')

        for module in check_list:
            if module not in self.pretrain_models:
                self.logger.log(logging.WARNING, f'{module} not initialized.')
                not_finish = True

        if not not_finish:
            self.logger.log(level, f'All initialized.')

        return not not_finish

    def load_models(self, source='huggingface'):
        if source == 'huggingface':
            download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})

    def _load(
        self,
        vocos_config_path: str = None,
        vocos_ckpt_path: str = None,
        dvae_config_path: str = None,
        dvae_ckpt_path: str = None,
        gpt_config_path: str = None,
        gpt_ckpt_path: str = None,
        decoder_config_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
        device: str = None
    ):
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            vocos.load_state_dict(torch.load(vocos_ckpt_path))
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**cfg).to(device).eval()
            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**cfg).to(device).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
            self.pretrain_models['gpt'] = gpt
            self.logger.log(logging.INFO, 'gpt loaded.')

        if decoder_config_path:
            cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**cfg).to(device).eval()
            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            tokenizer = torch.load(tokenizer_path, map_location='cpu')
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()

    def infer(self, text, skip_refine_text=False, params_refine_text={}, params_infer_code={}, use_decoder=False):
        assert self.check_model(use_decoder=use_decoder)
        if not skip_refine_text:
            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
        if use_decoder:
            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
        else:
            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
        return wav
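The Chat class above is the whole public workflow: load_models() pulls the *.pt checkpoints and *.yaml configs from the 2Noise/ChatTTS Hugging Face repo and wires up vocos, dvae, gpt, decoder, and tokenizer; infer() optionally refines the input text, runs the GPT to produce codes (or hidden states when use_decoder=True), decodes them to mel spectrograms, and renders waveforms with Vocos. A usage sketch under those assumptions; the soundfile dependency and the 24000 Hz sample rate are illustrative additions, not values stated in this commit:

import soundfile as sf   # assumed helper for writing the output; not part of this commit
import ChatTTS           # assumed package name

chat = ChatTTS.Chat()
chat.load_models()       # downloads *.pt / *.yaml from the 2Noise/ChatTTS hub repo

texts = ['This is a short sentence for a quick smoke test.']
wavs = chat.infer(texts)                      # one waveform array per input string

# write the first waveform to disk; 24000 Hz is an assumed Vocos output rate
sf.write('output.wav', wavs[0].squeeze(), 24000)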
@@ -0,0 +1,40 @@
from openai import OpenAI

prompt_dict = {
    'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
              {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
              {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
    'deepseek': [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
        {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
    'deepseek_TN': [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
        {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
        {"role": "user", "content": "We paid $123 for this desk."},
        {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
        {"role": "user", "content": "详询请拨打010-724654"},
        {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
        {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
        {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
    ],
}

class llm_api:
    def __init__(self, api_key, base_url, model):
        self.client = OpenAI(
            api_key = api_key,
            base_url = base_url,
        )
        self.model = model
    def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):

        completion = self.client.chat.completions.create(
            model = self.model,
            messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
            temperature = temperature,
            **kwargs
        )
        return completion.choices[0].message.content
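llm_api is a thin wrapper around an OpenAI-compatible chat-completions endpoint: it prepends one of the few-shot prompt sets above (kimi, deepseek, or deepseek_TN for text normalization) and returns the assistant's reply as a plain string. A usage sketch; the environment variable, base URL, and model name below are illustrative assumptions, not values defined in this commit:

import os

# illustrative values; substitute the endpoint and model of whichever
# OpenAI-compatible provider you actually use
client = llm_api(
    api_key = os.environ['LLM_API_KEY'],        # assumed environment variable
    base_url = 'https://api.example.com/v1',    # placeholder endpoint
    model = 'your-chat-model',                  # placeholder model name
)

# normalize a sentence for TTS input using the deepseek_TN few-shot prompt
normalized = client.call('The store opened 128 new locations on May 5th.', prompt_version='deepseek_TN')
print(normalized)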