Skip to content

Commit

Permalink
init upload
Browse files Browse the repository at this point in the history
  • Loading branch information
lich99 committed May 27, 2024
0 parents commit 8a0e837
Show file tree
Hide file tree
Showing 12 changed files with 1,327 additions and 0 deletions.
163 changes: 163 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.ckpt
# C extensions
*.so
*.pt

# Distribution / packaging
.Python
outputs/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
asset/*
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
1 change: 1 addition & 0 deletions ChatTTS/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .core import Chat
118 changes: 118 additions & 0 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@

import os
import logging
from omegaconf import OmegaConf

import torch
from vocos import Vocos
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .infer.api import refine_text, infer_code

from huggingface_hub import snapshot_download

logging.basicConfig(level = logging.INFO)


class Chat:
def __init__(self, ):
self.pretrain_models = {}
self.logger = logging.getLogger(__name__)

def check_model(self, level = logging.INFO, use_decoder = False):
not_finish = False
check_list = ['vocos', 'gpt', 'tokenizer']

if use_decoder:
check_list.append('decoder')
else:
check_list.append('dvae')

for module in check_list:
if module not in self.pretrain_models:
self.logger.log(logging.WARNING, f'{module} not initialized.')
not_finish = True

if not not_finish:
self.logger.log(level, f'All initialized.')

return not not_finish

def load_models(self, source='huggingface'):
if source == 'huggingface':
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})

def _load(
self,
vocos_config_path: str = None,
vocos_ckpt_path: str = None,
dvae_config_path: str = None,
dvae_ckpt_path: str = None,
gpt_config_path: str = None,
gpt_ckpt_path: str = None,
decoder_config_path: str = None,
decoder_ckpt_path: str = None,
tokenizer_path: str = None,
device: str = None
):
if not device:
device = select_device(4096)
self.logger.log(logging.INFO, f'use {device}')

if vocos_config_path:
vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
vocos.load_state_dict(torch.load(vocos_ckpt_path))
self.pretrain_models['vocos'] = vocos
self.logger.log(logging.INFO, 'vocos loaded.')

if dvae_config_path:
cfg = OmegaConf.load(dvae_config_path)
dvae = DVAE(**cfg).to(device).eval()
assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
self.pretrain_models['dvae'] = dvae
self.logger.log(logging.INFO, 'dvae loaded.')

if gpt_config_path:
cfg = OmegaConf.load(gpt_config_path)
gpt = GPT_warpper(**cfg).to(device).eval()
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
self.pretrain_models['gpt'] = gpt
self.logger.log(logging.INFO, 'gpt loaded.')

if decoder_config_path:
cfg = OmegaConf.load(decoder_config_path)
decoder = DVAE(**cfg).to(device).eval()
assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
self.pretrain_models['decoder'] = decoder
self.logger.log(logging.INFO, 'decoder loaded.')

if tokenizer_path:
tokenizer = torch.load(tokenizer_path, map_location='cpu')
tokenizer.padding_side = 'left'
self.pretrain_models['tokenizer'] = tokenizer
self.logger.log(logging.INFO, 'tokenizer loaded.')

self.check_model()

def infer(self, text, skip_refine_text=False, params_refine_text={}, params_infer_code={}, use_decoder=False):
assert self.check_model(use_decoder=use_decoder)
if not skip_refine_text:
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
if use_decoder:
mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
else:
mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
return wav



40 changes: 40 additions & 0 deletions ChatTTS/experimental/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

from openai import OpenAI

prompt_dict = {
'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
{"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
{"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
'deepseek': [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
{"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
'deepseek_TN': [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
{"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
{"role": "user", "content": "We paid $123 for this desk."},
{"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
{"role": "user", "content": "详询请拨打010-724654"},
{"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
{"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
{"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
],
}

class llm_api:
def __init__(self, api_key, base_url, model):
self.client = OpenAI(
api_key = api_key,
base_url = base_url,
)
self.model = model
def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):

completion = self.client.chat.completions.create(
model = self.model,
messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
temperature = temperature,
**kwargs
)
return completion.choices[0].message.content
Loading

0 comments on commit 8a0e837

Please sign in to comment.