Skip to content

Commit

Permalink
add TN
Browse files Browse the repository at this point in the history
  • Loading branch information
lich99 committed May 31, 2024
1 parent 2b872fb commit a80439e
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 8 deletions.
30 changes: 23 additions & 7 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@

import os
import logging
from functools import partial
from omegaconf import OmegaConf

import torch
from vocos import Vocos
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.infer_utils import count_invalid_characters, detect_language
from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map
from .utils.io_utils import get_latest_modified_file
from .infer.api import refine_text, infer_code

Expand Down Expand Up @@ -130,7 +131,7 @@ def infer(
params_refine_text={},
params_infer_code={'prompt':'[speed_5]'},
use_decoder=True,
do_text_normalization=False,
do_text_normalization=True,
lang=None,
):

Expand All @@ -143,12 +144,15 @@ def infer(
for i, t in enumerate(text):
_lang = detect_language(t) if lang is None else lang
self.init_normalizer(_lang)
text[i] = self.normalizer[_lang].normalize(t, verbose=False, punct_post_process=True)
text[i] = self.normalizer[_lang](t)
if _lang == 'zh':
text[i] = apply_half2full_map(text[i])

for i in text:
invalid_characters = count_invalid_characters(i)
for i, t in enumerate(text):
invalid_characters = count_invalid_characters(t)
if len(invalid_characters):
self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
text[i] = apply_character_map(t)

if not skip_refine_text:
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
Expand Down Expand Up @@ -179,6 +183,18 @@ def sample_random_speaker(self, ):
def init_normalizer(self, lang):

if lang not in self.normalizer:
from nemo_text_processing.text_normalization.normalize import Normalizer
self.normalizer[lang] = Normalizer(input_case='cased', lang=lang)
if lang == 'zh':
try:
from tn.chinese.normalizer import Normalizer
except:
self.logger.log(logging.WARNING, f'Package WeTextProcessing not found! \
Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing')
self.normalizer[lang] = Normalizer().normalize
else:
try:
from nemo_text_processing.text_normalization.normalize import Normalizer
except:
self.logger.log(logging.WARNING, f'Package nemo_text_processing not found! \
Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing')
self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True)

77 changes: 76 additions & 1 deletion ChatTTS/utils/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,79 @@ def detect_language(sentence):
if len(chinese_chars) > len(english_words):
return "zh"
else:
return "en"
return "en"


character_map = {
':': ',',
';': ',',
'!': '。',
'(': ',',
')': ',',
'【': ',',
'】': ',',
'『': ',',
'』': ',',
'「': ',',
'」': ',',
'《': ',',
'》': ',',
'-': ',',
'‘': '',
'“': '',
'’': '',
'”': '',
':': ',',
';': ',',
'!': '.',
'(': ',',
')': ',',
'[': ',',
']': ',',
'>': ',',
'<': ',',
'-': ',',
}

halfwidth_2_fullwidth_map = {
'!': '!',
'"': '“',
"'": '‘',
'#': '#',
'$': '$',
'%': '%',
'&': '&',
'(': '(',
')': ')',
',': ',',
'-': '-',
'*': '*',
'+': '+',
'.': '。',
'/': '/',
':': ':',
';': ';',
'<': '<',
'=': '=',
'>': '>',
'?': '?',
'@': '@',
'[': '[',
'\\': '\',
']': ']',
'^': '^',
'_': '_',
'`': '`',
'{': '{',
'|': '|',
'}': '}',
'~': '~'
}

def apply_half2full_map(text):
translation_table = str.maketrans(halfwidth_2_fullwidth_map)
return text.translate(translation_table)

def apply_character_map(text):
translation_table = str.maketrans(character_map)
return text.translate(translation_table)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ einops
vector_quantize_pytorch
transformers~=4.41.1
vocos
IPython

0 comments on commit a80439e

Please sign in to comment.