Skip to content

Commit

Permalink
Support loading models from a local path & returning the refined text directly
Browse files Browse the repository at this point in the history
  • Loading branch information
lich99 committed May 28, 2024
1 parent af1c0d8 commit 0df145b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
39 changes: 35 additions & 4 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.io_utils import get_latest_modified_file
from .infer.api import refine_text, infer_code

from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -39,11 +40,23 @@ def check_model(self, level = logging.INFO, use_decoder = False):

return not not_finish

def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
    """Load the pretrained model weights.

    Args:
        source: 'huggingface' to use the HF hub cache (downloading on a cache
            miss), or 'local' to load from *local_path*.
        force_redownload: if True, re-download from the hub even when a cached
            snapshot exists. Only meaningful for source='huggingface'.
        local_path: directory containing the model files and
            'config/path.yaml'. Only used when source='local'.

    Raises:
        ValueError: if *source* is not one of the supported values.
    """
    if source == 'huggingface':
        hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
        try:
            download_path = get_latest_modified_file(
                os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
        except Exception:
            # Cache directory missing or unreadable — fall through to a fresh download.
            download_path = None
        if download_path is None or force_redownload:
            self.logger.log(logging.INFO, 'Download from HF: https://huggingface.co/2Noise/ChatTTS')
            download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
        else:
            self.logger.log(logging.INFO, f'Load from cache: {download_path}')
        self._load_from_path(download_path)
    elif source == 'local':
        self.logger.log(logging.INFO, f'Load from local: {local_path}')
        self._load_from_path(local_path)
    else:
        # Previously an unknown source was silently ignored, leaving the
        # models unloaded and causing a confusing failure later on.
        raise ValueError(f"Unsupported source: {source!r} (expected 'huggingface' or 'local')")

def _load_from_path(self, base_path):
    """Resolve 'config/path.yaml' under *base_path* and load every listed model file."""
    config = OmegaConf.load(os.path.join(base_path, 'config', 'path.yaml'))
    self._load(**{k: os.path.join(base_path, v) for k, v in config.items()})

def _load(
self,
vocos_config_path: str = None,
Expand Down Expand Up @@ -100,18 +113,36 @@ def _load(

self.check_model()

def infer(
    self,
    text,
    skip_refine_text=False,
    refine_text_only=False,
    params_refine_text=None,
    params_infer_code=None,
    use_decoder=False
):
    """Synthesize audio (or refined text) for *text*.

    Args:
        text: list of input strings to synthesize.
        skip_refine_text: if True, skip the text-refinement pass.
        refine_text_only: if True, return the refined text instead of audio.
        params_refine_text: optional kwargs forwarded to refine_text().
        params_infer_code: optional kwargs forwarded to infer_code(); a
            'prompt' entry, if present, is prepended to each input text.
        use_decoder: decode via the 'decoder' model on hidden states instead
            of the 'dvae' model on token ids.

    Returns:
        List of refined strings when refine_text_only is set (and refinement
        ran), otherwise a list of numpy waveform arrays, one per input.
    """
    assert self.check_model(use_decoder=use_decoder)

    # None defaults instead of mutable `{}` defaults (shared across calls),
    # and a defensive copy so popping 'prompt' below cannot mutate the
    # caller's dict between calls.
    params_refine_text = {} if params_refine_text is None else params_refine_text
    params_infer_code = dict(params_infer_code) if params_infer_code else {}

    if not skip_refine_text:
        text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
        # Drop everything from the first break token onward.
        break_id = self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')
        text_tokens = [i[i < break_id] for i in text_tokens]
        text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
        if refine_text_only:
            return text

    prompt = params_infer_code.pop('prompt', '')
    text = [prompt + i for i in text]
    result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)

    if use_decoder:
        mel_spec = [self.pretrain_models['decoder'](i[None].permute(0, 2, 1)) for i in result['hiddens']]
    else:
        mel_spec = [self.pretrain_models['dvae'](i[None].permute(0, 2, 1)) for i in result['ids']]

    wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]

    return wav


Expand Down
14 changes: 14 additions & 0 deletions ChatTTS/utils/io_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

import os
import logging

def get_latest_modified_file(directory):
    """Return the path of the most recently modified entry in *directory*.

    Args:
        directory: directory to scan (entries may be files or subdirectories,
            e.g. HF snapshot hash directories).

    Returns:
        Full path of the entry with the newest mtime, or None when the
        directory does not exist or is empty (a warning is logged).
    """
    logger = logging.getLogger(__name__)

    # Previously a missing directory raised FileNotFoundError from
    # os.listdir, forcing callers to wrap this in a bare except.
    if not os.path.isdir(directory):
        logger.warning('Directory does not exist: %s', directory)
        return None

    files = [os.path.join(directory, f) for f in os.listdir(directory)]
    if not files:
        logger.warning('No files found in the directory: %s', directory)
        return None

    return max(files, key=os.path.getmtime)

0 comments on commit 0df145b

Please sign in to comment.