Skip to content

Commit

Permalink
add invalid character check
Browse files Browse the repository at this point in the history
  • Loading branch information
lich99 committed May 30, 2024
1 parent b978e88 commit a29fa40
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
7 changes: 7 additions & 0 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.infer_utils import count_invalid_characters
from .utils.io_utils import get_latest_modified_file
from .infer.api import refine_text, infer_code

Expand Down Expand Up @@ -134,6 +135,12 @@ def infer(

assert self.check_model(use_decoder=use_decoder)


for i in text:
invalid_characters = count_invalid_characters(i)
if len(invalid_characters):
self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')

if not skip_refine_text:
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
Expand Down
10 changes: 9 additions & 1 deletion ChatTTS/utils/infer_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import re
import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -42,4 +43,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
scores.scatter_(1, input_ids, score)

return scores
return scores

def count_invalid_characters(s):
    """Return the set of distinct unsupported characters found in *s*.

    The control tags ``[uv_break]``, ``[laugh]`` and ``[lbreak]`` are removed
    first so that their brackets and underscores are not reported. Every
    remaining character that falls outside the allowed class (code points
    U+4E00-U+9FFF, ASCII letters, the punctuation listed in the pattern and
    the space character) is treated as invalid.

    Note that, despite the name, the result is a ``set`` of the offending
    characters rather than a count; it is empty when nothing is invalid.
    """
    # Strip the known control tags before scanning the text.
    without_tags = re.sub(r'\[uv_break\]|\[laugh\]|\[lbreak\]', '', s)
    invalid_char = re.compile(r'[^\u4e00-\u9fffA-Za-z,。,\. ]')
    # Each match is exactly one character; the set collapses duplicates.
    return {match.group() for match in invalid_char.finditer(without_tags)}

0 comments on commit a29fa40

Please sign in to comment.