Skip to content

Commit

Permalink
add text normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
lich99 committed May 30, 2024
1 parent a29fa40 commit 7a323cd
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
19 changes: 16 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.infer_utils import count_invalid_characters
from .utils.infer_utils import count_invalid_characters, detect_language
from .utils.io_utils import get_latest_modified_file
from .infer.api import refine_text, infer_code

Expand All @@ -22,6 +22,7 @@
class Chat:
# Set up empty registries; nothing is loaded or constructed here.
def __init__(self, ):
# Registry of loaded model components, looked up by string key
# (e.g. 'gpt' and 'spk_stat' are read by sample_random_speaker).
self.pretrain_models = {}
# Cache of text normalizers keyed by language code; filled on demand by
# init_normalizer so each language's normalizer is built at most once.
self.normalizer = {}
# Logger named after this module.
self.logger = logging.getLogger(__name__)

def check_model(self, level = logging.INFO, use_decoder = False):
Expand Down Expand Up @@ -130,12 +131,19 @@ def infer(
refine_text_only=False,
params_refine_text={},
params_infer_code={},
use_decoder=False
use_decoder=True,
do_text_normalization=True,
lang=None,
):

assert self.check_model(use_decoder=use_decoder)


if do_text_normalization:
for i, t in enumerate(text):
_lang = detect_language(t) if lang is None else lang
self.init_normalizer(_lang)
text[i] = self.normalizer[_lang].normalize(t, verbose=False, punct_post_process=True)

for i in text:
invalid_characters = count_invalid_characters(i)
if len(invalid_characters):
Expand Down Expand Up @@ -166,5 +174,10 @@ def sample_random_speaker(self, ):
dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
std, mean = self.pretrain_models['spk_stat'].chunk(2)
return torch.randn(dim, device=std.device) * std + mean

def init_normalizer(self, lang):
    """Create and cache the text normalizer for ``lang`` on first use.

    The normalizer is stored in ``self.normalizer[lang]``; later calls with
    the same ``lang`` return immediately without touching the cache.

    Args:
        lang: Language code forwarded to the ``Normalizer`` constructor and
            used as the cache key (``infer`` passes ``"zh"`` or ``"en"``
            from ``detect_language`` unless the caller supplies its own).

    Raises:
        ImportError: If ``nemo_text_processing`` is not installed. The
            original exception is re-raised unchanged after logging a hint.
    """
    if lang in self.normalizer:
        return

    try:
        # Imported lazily so the dependency is only required when text
        # normalization is actually requested.
        from nemo_text_processing.text_normalization.normalize import Normalizer
    except ImportError:
        # Normalization is enabled by default in infer(), so point the user
        # at both remedies instead of surfacing a bare import failure.
        self.logger.error(
            "Text normalization for lang=%r requires the "
            "'nemo_text_processing' package; install it or call "
            "infer(..., do_text_normalization=False).",
            lang,
        )
        raise

    self.normalizer[lang] = Normalizer(input_case='cased', lang=lang)

15 changes: 14 additions & 1 deletion ChatTTS/utils/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,17 @@ def count_invalid_characters(s):
s = re.sub(r'\[uv_break\]|\[laugh\]|\[lbreak\]', '', s)
pattern = re.compile(r'[^\u4e00-\u9fffA-Za-z,。,\. ]')
non_alphabetic_chinese_chars = pattern.findall(s)
return set(non_alphabetic_chinese_chars)
return set(non_alphabetic_chinese_chars)

def detect_language(sentence):

chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
english_word_pattern = re.compile(r'\b[A-Za-z]+\b')

chinese_chars = chinese_char_pattern.findall(sentence)
english_words = english_word_pattern.findall(sentence)

if len(chinese_chars) > len(english_words):
return "zh"
else:
return "en"

0 comments on commit 7a323cd

Please sign in to comment.