forked from 2noise/ChatTTS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
#655.py
96 lines (78 loc) · 2.27 KB
/
#655.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os, sys
# Working directory the script was launched from; appended to sys.path below so
# that the project-local `ChatTTS` and `tools` imports that follow can resolve
# (presumably the script is meant to be run from the repository root).
now_dir = os.getcwd()
# On macOS, ask PyTorch to fall back to CPU for operators the MPS backend lacks.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
sys.path.append(now_dir)
import logging
import torch
import ChatTTS
from tools.logger import get_logger
from tools.normalizer import normalizer_en_nemo_text
# Only warnings and above are shown (logging.WARNING is the same level as logging.WARN).
logger = get_logger("Test", lv=logging.WARNING)
chat = ChatTTS.Chat(logger)
# compile=False here; set it to True for better performance.
chat.load(source="huggingface", compile=False)
# Optionally register the NeMo-based English normalizer. The test can run
# without it, so every failure is logged as a warning instead of raised.
try:
    chat.normalizer.register("en", normalizer_en_nemo_text())
except ImportError:
    # Presumably raised by normalizer_en_nemo_text() when nemo_text_processing
    # is not installed.
    logger.warning("Package nemo_text_processing not found!")
except Exception as err:
    # Still non-fatal, but report the actual cause rather than claiming the
    # package is missing. (A bare `except:` would also have swallowed
    # KeyboardInterrupt / SystemExit.)
    logger.warning("Failed to register English normalizer: %s", err)
# Shared inputs for both checks below. `fail` is set to True by any mismatch.
rand_spk = chat.sample_random_speaker()
text = ["What is [uv_break]your favorite english food?[laugh][lbreak]"]
fail = False

# Check 1: refine-only inference with a fixed seed must produce this exact text.
refine_params = ChatTTS.Chat.RefineTextParams(
    prompt="[oral_2][laugh_0][break_6]",
    manual_seed=12345,
)
refined_text = chat.infer(
    text,
    refine_text_only=True,
    params_refine_text=refine_params,
)
expected_refined = (
    "what is [uv_break] your favorite english [uv_break] food [laugh] like [lbreak]"
)
if refined_text[0] != expected_refined:
    fail = True
    logger.warning("refined text is '%s'", refined_text[0])
# Check 2: encode `text` with the speaker/prompt decoration from
# chat.speaker.decorate_code_prompts, decode the resulting token ids back to
# text, and compare against the expected decorated string.
params = ChatTTS.Chat.InferCodeParams(
    spk_emb=rand_spk,  # add sampled speaker
    temperature=0.3,  # using custom temperature
    top_P=0.7,  # top P decode
    top_K=20,  # top K decode
)

input_ids, attention_mask, text_mask = chat.tokenizer.encode(
    chat.speaker.decorate_code_prompts(
        text,
        params.prompt,
        params.txt_smp,
        params.spk_emb,
    ),
    chat.config.gpt.num_vq,
    prompt=(
        chat.speaker.decode_prompt(params.spk_smp)
        if params.spk_smp is not None
        else None
    ),
    device=chat.device_gpt,
)

with torch.inference_mode():
    # Every row starts at 0 and ends at input_ids.shape[1]. This presumably
    # selects the whole encoded sequence (confirm against
    # GPT._prepare_generation_outputs).
    start_idx = 0
    end_idx = torch.full(
        (input_ids.shape[0],),
        input_ids.shape[1],
        device=input_ids.device,
        dtype=torch.long,
    )
    # NOTE(review): the two empty lists and True are passed positionally; their
    # meaning is defined by _prepare_generation_outputs, which is not visible here.
    recoded_text = chat.tokenizer.decode(
        chat.gpt._prepare_generation_outputs(
            input_ids,
            start_idx,
            end_idx,
            [],
            [],
            True,
        ).ids
    )

if (
    recoded_text[0]
    != "[Stts] [spk_emb] [speed_5] what is [uv_break] your favorite english food? [laugh] [lbreak] [Ptts]"
):
    fail = True
    # Log the decoded text that failed the comparison (previously this logged
    # refined_text by mistake, hiding the actual mismatch).
    logger.warning("recoded text is '%s'", recoded_text[0])
# Report the overall result via the exit status: 1 if either check above
# detected a mismatch. `sys` is already imported at the top of the file.
if fail:
    sys.exit(1)