forked from tufo830/virtual_human_stream
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ttsreal.py
419 lines (354 loc) · 15.5 KB
/
ttsreal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
import asyncio
import time
import traceback
from enum import Enum
from io import BytesIO
from typing import AsyncIterator

import aiohttp
import edge_tts
import numpy as np
import resampy
import soundfile as sf
class State(Enum):
    """Processing state of a TTS engine.

    Members:
        RUNNING: TTS processing is active and audio may be emitted.
        PAUSE: TTS processing is paused; queued messages are dropped.
    """

    RUNNING = 0  # actively synthesizing / streaming audio
    PAUSE = 1    # paused; streaming loops stop emitting frames
class BaseTTS:
    """
    Abstract base class for TTS (Text-to-Speech) engines.

    Manages the incoming text-message queue and the processing state, and
    runs the loop that turns queued text into audio frames.  Subclasses must
    implement ``txt_to_audio``.

    Attributes:
        opt: Configuration options for TTS (must expose ``fps``).
        parent: Object that receives audio frames via ``put_audio_frame``.
        fps (int): Frames per second, determining the number of samples per chunk.
        sample_rate (int): Target sample rate for audio processing.
        chunk (int): Number of samples per audio chunk.
        input_stream (BytesIO): Buffer to store incoming audio data.
        msgqueue (asyncio.Queue): Queue of pending text messages.
        state (State): Current state of TTS processing (RUNNING or PAUSE).
    """

    def __init__(self, opt, parent):
        """
        Initializes the BaseTTS instance with configuration and parent references.

        Args:
            opt: Configuration options for TTS.
            parent: Reference to the parent object that handles audio frames.
        """
        self.opt = opt
        self.parent = parent
        self.fps = opt.fps  # Frames per second (e.g., 50 for 20ms per frame)
        self.sample_rate = 16000  # Target sample rate for audio
        self.chunk = self.sample_rate // self.fps  # Samples per chunk (e.g., 320 for 20ms)
        self.input_stream = BytesIO()
        self.msgqueue = asyncio.Queue()
        self.state = State.RUNNING
        self._tts_task = None  # Strong reference to the worker task (see render)

    def pause_talk(self):
        """
        Pauses TTS processing: drains all queued messages and flips the state
        so in-flight streaming loops stop emitting frames.
        """
        while not self.msgqueue.empty():
            self.msgqueue.get_nowait()
        self.state = State.PAUSE

    def put_msg_txt(self, msg: str):
        """
        Adds a text message to the TTS processing queue.

        Args:
            msg (str): The text message to be converted to speech.
        """
        self.msgqueue.put_nowait(msg)

    def render(self, quit_event: asyncio.Event):
        """
        Starts the TTS processing coroutine as an asyncio Task.

        The task is kept on the instance because the event loop holds only a
        weak reference to tasks; without a strong reference here the worker
        could be garbage-collected before it finishes.

        Args:
            quit_event (asyncio.Event): Event to signal the coroutine to stop processing.
        """
        self._tts_task = asyncio.create_task(self.process_tts(quit_event))

    async def process_tts(self, quit_event: asyncio.Event):
        """
        Coroutine that continuously processes text messages from the queue and
        converts them to audio.  A failure while synthesizing one message is
        logged and the loop continues with the next message instead of dying.

        Args:
            quit_event (asyncio.Event): Event to signal the coroutine to stop processing.
        """
        while not quit_event.is_set():
            try:
                msg = await asyncio.wait_for(self.msgqueue.get(), timeout=1)
            except asyncio.TimeoutError:
                # No message within the poll interval; re-check quit_event.
                continue
            self.state = State.RUNNING
            try:
                await self.txt_to_audio(msg)
            except Exception:
                # One bad message must not kill the whole worker loop.
                traceback.print_exc()
        print('BaseTTS task stopped')

    async def txt_to_audio(self, msg: str):
        """
        Abstract method to convert text to audio. Must be implemented by subclasses.

        Args:
            msg (str): The text message to convert to audio.
        """
        raise NotImplementedError("Subclasses must implement txt_to_audio method")
###########################################################################################
class EdgeTTS(BaseTTS):
    """
    TTS engine backed by the Microsoft Edge TTS service (``edge_tts``).

    The whole message is synthesized into ``input_stream`` first, then the
    buffered audio is decoded and streamed to the parent in fixed-size chunks.
    """

    async def txt_to_audio(self, msg: str):
        """
        Converts text to audio using the Edge TTS engine and streams the audio in chunks.

        Args:
            msg (str): The text message to convert to audio.
        """
        voicename = "zh-CN-YunxiaNeural"  # NOTE(review): hard-coded voice; consider reading from self.opt
        start_time = time.time()
        await self.__main(voicename, msg)
        print(f'-------EdgeTTS processing time: {time.time() - start_time:.4f}s')
        await self.__stream_audio_chunks()

    async def __stream_audio_chunks(self):
        """
        Streams the buffered audio in fixed-size chunks to the parent handler,
        then clears the buffer.  Any trailing partial chunk (< self.chunk
        samples) is discarded.
        """
        if self.input_stream.getbuffer().nbytes == 0:
            # Nothing was synthesized (e.g. paused before any audio arrived);
            # decoding an empty buffer would raise inside soundfile.
            return
        self.input_stream.seek(0)
        stream = await self.__create_bytes_stream(self.input_stream)
        streamlen = stream.shape[0]
        idx = 0
        while streamlen >= self.chunk and self.state == State.RUNNING:
            self.parent.put_audio_frame(stream[idx:idx + self.chunk])
            streamlen -= self.chunk
            idx += self.chunk
        # Reset the buffer for the next message.
        self.input_stream.seek(0)
        self.input_stream.truncate()

    async def __create_bytes_stream(self, byte_stream: BytesIO) -> np.ndarray:
        """
        Reads and processes the byte stream into a normalized numpy array
        suitable for streaming (mono float32 at ``self.sample_rate``).

        Args:
            byte_stream (BytesIO): The byte stream containing audio data.

        Returns:
            np.ndarray: The processed audio stream.
        """
        loop = asyncio.get_event_loop()
        # Decode off the event loop; sf.read is blocking work.
        stream, sample_rate = await loop.run_in_executor(None, sf.read, byte_stream)
        print(f'[INFO] TTS audio stream sample rate: {sample_rate}, shape: {stream.shape}')
        stream = stream.astype(np.float32)
        if stream.ndim > 1:
            print(f'[WARN] Audio has {stream.shape[1]} channels, only using the first channel.')
            stream = stream[:, 0]
        if sample_rate != self.sample_rate and stream.shape[0] > 0:
            print(f'[WARN] Audio sample rate is {sample_rate}, resampling to {self.sample_rate}.')
            # Resample off the event loop as well; resampy is CPU-bound.
            stream = await loop.run_in_executor(None, resampy.resample, stream, sample_rate, self.sample_rate)
        return stream

    async def __main(self, voicename: str, text: str):
        """
        Requests synthesis from the Edge TTS engine and appends the received
        audio bytes to ``self.input_stream``.

        Args:
            voicename (str): The name of the voice to use for TTS.
            text (str): The text message to convert to audio.
        """
        communicate = edge_tts.Communicate(text, voicename)
        async for chunk in communicate.stream():
            # Only audio-typed chunks are buffered, and only while RUNNING;
            # audio arriving after a pause is dropped.
            if chunk["type"] == "audio" and self.state == State.RUNNING:
                self.input_stream.write(chunk["data"])
###########################################################################################
class VoitsTTS(BaseTTS):
    """
    TTS engine backed by a GPT-SoVITS ("Voits") streaming server.

    Audio arrives over HTTP as raw 16-bit PCM at 32 kHz and is resampled to
    ``self.sample_rate`` before being forwarded to the parent.
    """

    async def txt_to_audio(self, msg: str):
        """
        Converts text to audio using the Voits TTS engine and streams the audio in chunks.

        Args:
            msg (str): The text message to convert to audio.
        """
        audio_stream = self.gpt_sovits(
            msg,
            self.opt.REF_FILE,
            self.opt.REF_TEXT,
            "zh",
            self.opt.TTS_SERVER,
        )
        await self.stream_tts(audio_stream)

    async def gpt_sovits(self, text: str, reffile: str, reftext: str, language: str, server_url: str) -> AsyncIterator[bytes]:
        """
        Sends a request to the Voits TTS server and yields audio chunks as they are received.

        Args:
            text (str): The text to convert to speech.
            reffile (str): Path to the reference audio file.
            reftext (str): Reference text for speaker cloning.
            language (str): Language code (e.g., "zh").
            server_url (str): URL of the Voits TTS server.

        Yields:
            bytes: Raw PCM audio chunks received from the server.
        """
        start = time.perf_counter()
        req = {
            'text': text,
            'text_lang': language,
            'ref_audio_path': reffile,
            'prompt_text': reftext,
            'prompt_lang': language,
            'media_type': 'raw',
            'streaming_mode': True
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{server_url}/tts", json=req) as res:
                end = time.perf_counter()
                print(f"VoitsTTS: Time to make POST request: {end - start:.4f}s")
                if res.status != 200:
                    error_text = await res.text()
                    print(f"VoitsTTS Error: {error_text}")
                    return
                first = True
                async for chunk in res.content.iter_chunked(32000):
                    if first:
                        end = time.perf_counter()
                        print(f"VoitsTTS: Time to first chunk: {end - start:.4f}s")
                        first = False
                    if self.state != State.RUNNING:
                        # Pause requested: stop pulling data from the server
                        # instead of draining and discarding the rest.
                        break
                    if chunk:
                        yield chunk
                print(f"VoitsTTS response elapsed time: {res.headers.get('X-Response-Time')}")

    async def stream_tts(self, audio_stream: AsyncIterator[bytes]):
        """
        Streams the audio chunks by resampling and sending them to the parent handler.

        HTTP chunk boundaries are arbitrary, so a chunk may end mid-sample; an
        odd trailing byte is carried over to the next chunk rather than being
        passed to ``np.frombuffer`` (which raises unless the buffer length is
        a multiple of the int16 item size).

        Args:
            audio_stream (AsyncIterator[bytes]): Asynchronous iterator of audio chunks.
        """
        leftover = b''
        async for chunk in audio_stream:
            if not chunk:
                continue
            data = leftover + chunk
            usable = len(data) - (len(data) % 2)  # int16 samples are 2 bytes each
            leftover = data[usable:]
            if usable:
                await self.__process_and_stream_chunk(data[:usable], 32000)

    async def __process_and_stream_chunk(self, chunk: bytes, original_sample_rate: int):
        """
        Converts one raw int16 PCM chunk to float32, resamples it to the
        target rate, and forwards it to the parent in fixed-size frames.
        Trailing samples smaller than one frame are discarded.

        Args:
            chunk (bytes): Audio chunk to process (even number of bytes).
            original_sample_rate (int): Original sample rate of the audio chunk.
        """
        loop = asyncio.get_event_loop()
        stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
        # Resample off the event loop; resampy is CPU-bound.
        stream = await loop.run_in_executor(None, resampy.resample, stream, original_sample_rate, self.sample_rate)
        streamlen = stream.shape[0]
        idx = 0
        while streamlen >= self.chunk:
            self.parent.put_audio_frame(stream[idx:idx + self.chunk])
            streamlen -= self.chunk
            idx += self.chunk
###########################################################################################
class XTTS(BaseTTS):
    """
    TTS engine backed by an XTTS streaming server.

    A speaker embedding is cloned once from a reference audio file, then each
    message is synthesized via the server's streaming endpoint.  Audio arrives
    as raw 16-bit PCM at 24 kHz and is resampled to ``self.sample_rate``.

    Attributes:
        speaker (dict | None): Speaker configuration from the TTS server,
            fetched lazily on the first message.
    """

    def __init__(self, opt, parent):
        """
        Initializes the XTTS instance and sets the speaker configuration to None.

        Args:
            opt: Configuration options for TTS.
            parent: Reference to the parent object that handles audio frames.
        """
        super().__init__(opt, parent)
        self.speaker = None  # Lazily fetched from the server on first use

    async def txt_to_audio(self, msg: str):
        """
        Converts text to audio using the XTTS engine and streams the audio in chunks.

        Args:
            msg (str): The text message to convert to audio.
        """
        if not self.speaker:
            # Clone the speaker once; an empty dict (server/file error) is
            # falsy, so cloning is retried on the next message.
            self.speaker = await self.get_speaker(self.opt.REF_FILE, self.opt.TTS_SERVER)
        audio_stream = self.xtts(
            msg,
            self.speaker,
            "zh-cn",
            self.opt.TTS_SERVER,
            "20"
        )
        await self.stream_tts(audio_stream)

    async def get_speaker(self, ref_audio: str, server_url: str) -> dict:
        """
        Obtains the speaker configuration by sending a reference audio file to the XTTS server.

        Args:
            ref_audio (str): Path to the reference audio file.
            server_url (str): URL of the XTTS server.

        Returns:
            dict: Speaker configuration from the server, or {} on any failure.
        """
        data = aiohttp.FormData()
        try:
            with open(ref_audio, 'rb') as f:
                data.add_field('wav_file', f, filename='reference.wav')
                # The POST must happen while the file is still open, since
                # FormData streams the file handle during the upload.
                async with aiohttp.ClientSession() as session:
                    async with session.post(f"{server_url}/clone_speaker", data=data) as response:
                        if response.status != 200:
                            error_text = await response.text()
                            print(f"XTTS Error: {error_text}")
                            return {}
                        return await response.json()
        except FileNotFoundError:
            print(f"XTTS Error: Reference audio file '{ref_audio}' not found.")
            return {}
        except Exception as e:
            print(f"XTTS Exception: {e}")
            return {}

    async def xtts(self, text: str, speaker: dict, language: str, server_url: str, stream_chunk_size: str) -> AsyncIterator[bytes]:
        """
        Sends a request to the XTTS server and yields audio chunks as they are received.

        Args:
            text (str): The text to convert to speech.
            speaker (dict): Speaker configuration.
            language (str): Language code (e.g., "zh-cn").
            server_url (str): URL of the XTTS server.
            stream_chunk_size (str): Size of each audio stream chunk.

        Yields:
            bytes: Raw PCM audio chunks received from the server.
        """
        start = time.perf_counter()
        # Build the request payload on a copy so the cached speaker dict is
        # not mutated with per-request keys.
        payload = dict(speaker)
        payload["text"] = text
        payload["language"] = language
        payload["stream_chunk_size"] = stream_chunk_size
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{server_url}/tts_stream", json=payload) as res:
                end = time.perf_counter()
                print(f"XTTS: Time to make POST request: {end - start:.4f}s")
                if res.status != 200:
                    error_text = await res.text()
                    print(f"XTTS Error: {error_text}")
                    return
                first = True
                async for chunk in res.content.iter_chunked(960):
                    if first:
                        end = time.perf_counter()
                        print(f"XTTS: Time to first chunk: {end - start:.4f}s")
                        first = False
                    if self.state != State.RUNNING:
                        # Pause requested: stop consuming the stream
                        # (consistent with the other streaming engines here).
                        break
                    if chunk:
                        yield chunk
                print(f"XTTS response elapsed time: {res.headers.get('X-Response-Time')}")

    async def stream_tts(self, audio_stream: AsyncIterator[bytes]):
        """
        Streams the audio chunks by resampling and sending them to the parent handler.

        HTTP chunk boundaries are arbitrary, so a chunk may end mid-sample; an
        odd trailing byte is carried over to the next chunk rather than being
        passed to ``np.frombuffer`` (which raises unless the buffer length is
        a multiple of the int16 item size).

        Args:
            audio_stream (AsyncIterator[bytes]): Asynchronous iterator of audio chunks.
        """
        leftover = b''
        async for chunk in audio_stream:
            if not chunk:
                continue
            data = leftover + chunk
            usable = len(data) - (len(data) % 2)  # int16 samples are 2 bytes each
            leftover = data[usable:]
            if usable:
                await self.__process_and_stream_chunk(data[:usable], 24000)

    async def __process_and_stream_chunk(self, chunk: bytes, original_sample_rate: int):
        """
        Converts one raw int16 PCM chunk to float32, resamples it to the
        target rate, and forwards it to the parent in fixed-size frames.
        Trailing samples smaller than one frame are discarded.

        Args:
            chunk (bytes): Audio chunk to process (even number of bytes).
            original_sample_rate (int): Original sample rate of the audio chunk.
        """
        loop = asyncio.get_event_loop()
        stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
        # Resample off the event loop; resampy is CPU-bound.
        stream = await loop.run_in_executor(None, resampy.resample, stream, original_sample_rate, self.sample_rate)
        streamlen = stream.shape[0]
        idx = 0
        while streamlen >= self.chunk:
            self.parent.put_audio_frame(stream[idx:idx + self.chunk])
            streamlen -= self.chunk
            idx += self.chunk