Skip to content

Commit

Permalink
Add --no-tts optional flag to disable XTTS
Browse files Browse the repository at this point in the history
  • Loading branch information
Elbios committed Jan 31, 2024
1 parent 150c4ac commit 07e38bd
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 45 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ Download Automatic1111 here:
Rough notes - I will polish the README later:

### XTTS
- Current revision uses XTTS2 (uses the TTS Python library, lookup on coqui.ai) and it is REQUIRED (will make it as an option later)
- Current revision uses XTTS2 (uses the TTS Python library, lookup on coqui.ai)
- If you do not want to use TTS, pass the `--no-tts` flag, like so: `python bot.py --no-tts`
- XTTS2 uses some RAM/VRAM so bear that in mind
- for setup I used windows gpu steps from https://github.com/daswer123/xtts-api-server (no need to clone xtts-api-server, just do the install steps from its README)
- definitely need 'pip install pillow' and 'pip install TTS'
Expand Down Expand Up @@ -78,6 +79,7 @@ To run this bot:
4. Install the requirements. I suggest using an Anaconda or Miniconda instance.
```pip install -r requirements.txt```
5. Run the bot with `python bot.py`
- Optionally add the `--no-tts` flag to run without XTTS: `python bot.py --no-tts`

Cheers!

Expand Down
116 changes: 72 additions & 44 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
import re
import unittest
import logging
import argparse

from typing import List
from pydub import AudioSegment
from PIL import Image

# For xtts2 TTS
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# For xtts2 TTS (now imported conditionally at the bottom of the script)
#import torch
#import torchaudio
#from TTS.tts.configs.xtts_config import XttsConfig
#from TTS.tts.models.xtts import Xtts

from aiohttp import ClientSession
from aiohttp import ClientTimeout
Expand All @@ -34,6 +36,9 @@
# Your API keys and tokens go here. Do not commit with these in place!
discord_api_key = "PUT_YOUR_API_KEY_HERE"

# Flag to disable TTS (XTTS) - use --no-tts when launching
disable_tts_global = False

intents = discord.Intents.all()
intents.message_content = True

Expand Down Expand Up @@ -454,6 +459,9 @@ async def send_to_stable_diffusion_queue():

# New function to handle TTS generation
async def generate_tts(text):
if 'Xtts' not in globals():
return None

await functions.write_to_log("Generating TTS for the given text...")
out = tts_model.inference(
text,
Expand Down Expand Up @@ -593,35 +601,39 @@ async def send_to_user_queue():

# Do not do TTS for image gen responses
if not reply["content"]["image"]:
try:
# After getting the dialogue, split it
dialogue_parts = split_dialogue(reply["response"], 200)

# Generate TTS audio for each part
audio_parts = []
for part in dialogue_parts:
# better not send emojis to TTS
part = strip_emoji(part)
audio_path = await generate_tts(part)
audio_parts.append(AudioSegment.from_wav(audio_path))
os.remove(audio_path)

# Create a silent audio segment of 0.5 seconds (500 milliseconds)
silence = AudioSegment.silent(duration=800)

# Add the silent segment between each pair of audio segments
combined_audio = sum(x for y in zip(audio_parts, [silence]*len(audio_parts)) for x in y)

# Save the combined audio file
md5hash = hashlib.md5(reply["response"].encode('utf-8'))
md5hash_hex = md5hash.hexdigest()
combined_audio_path = "tts_output_" + md5hash_hex + ".wav"
combined_audio.export(combined_audio_path, format="wav")

# Send the combined audio file
audio_file = discord.File(combined_audio_path)
except Exception as e:
# TTS generation failed, skip TTS audio
if 'Xtts' in globals():
try:
# After getting the dialogue, split it
dialogue_parts = split_dialogue(reply["response"], 200)

# Generate TTS audio for each part
audio_parts = []
for part in dialogue_parts:
# better not send emojis to TTS
part = strip_emoji(part)
audio_path = await generate_tts(part)
audio_parts.append(AudioSegment.from_wav(audio_path))
os.remove(audio_path)

# Create a silent audio segment of 0.8 seconds (800 milliseconds)
silence = AudioSegment.silent(duration=800)

# Interleave the audio segments with silence (a silent segment follows every part, including the last one)
combined_audio = sum(x for y in zip(audio_parts, [silence]*len(audio_parts)) for x in y)

# Save the combined audio file
md5hash = hashlib.md5(reply["response"].encode('utf-8'))
md5hash_hex = md5hash.hexdigest()
combined_audio_path = "tts_output_" + md5hash_hex + ".wav"
combined_audio.export(combined_audio_path, format="wav")

# Send the combined audio file
audio_file = discord.File(combined_audio_path)
except Exception as e:
# TTS generation failed, skip TTS audio
audio_file = None
else:
# TTS disabled, skip TTS audio
audio_file = None

if not reply["content"]["channel"]:
Expand Down Expand Up @@ -654,16 +666,17 @@ async def on_ready():
logging.basicConfig(level=logging.DEBUG)

# Load TTS model
await functions.write_to_log("Loading TTS model...")
tts_config = XttsConfig()
tts_config.load_json("xtts/config.json")
tts_model = Xtts.init_from_config(tts_config)
tts_model.load_checkpoint(tts_config, checkpoint_dir="./xtts/model", use_deepspeed=False)
tts_model.cuda()

# Compute speaker latents
await functions.write_to_log("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = tts_model.get_conditioning_latents(audio_path=[r"xtts/scarlett24000.wav"])
if 'Xtts' in globals():
await functions.write_to_log("Loading TTS model...")
tts_config = XttsConfig()
tts_config.load_json("xtts/config.json")
tts_model = Xtts.init_from_config(tts_config)
tts_model.load_checkpoint(tts_config, checkpoint_dir="./xtts/model", use_deepspeed=False)
tts_model.cuda()

# Compute speaker latents
await functions.write_to_log("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = tts_model.get_conditioning_latents(audio_path=[r"xtts/scarlett24000.wav"])

text_api = await functions.set_api("text-default.json")
image_api = await functions.set_api("image-default.json")
Expand Down Expand Up @@ -862,6 +875,21 @@ async def parameter_select_callback(interaction):
# Let the user know that their request has been completed
await interaction.followup.send(interaction.user.name + " updated the bot's sampler parameters. " + api_check)

# ---------------------------------------------------------------------------
# Script launch: read the command line, optionally load XTTS, start the bot.
# ---------------------------------------------------------------------------

# The only supported option is --no-tts, which skips loading the XTTS
# libraries altogether.
parser = argparse.ArgumentParser(description='AI Discord bot, use --no-tts to disable XTTS and save some VRAM')
parser.add_argument('--no-tts', action='store_true', help='Flag to disable TTS (XTTS)')
args = parser.parse_args()
disable_tts_global = args.no_tts

if disable_tts_global:
    print("Running without XTTS - no speech audio will be generated")
else:
    # These imports intentionally happen at module scope: other parts of this
    # script test `'Xtts' in globals()` to decide whether TTS is available, so
    # the names must land in the module's global namespace.
    print("Running with XTTS (will eat some VRAM)")
    print("Loading XTTS libraries...")
    import torch
    import torchaudio
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import Xtts
    print("XTTS libraries imported.")

# Hand control to the Discord client.
client.run(discord_api_key)
#unittest.main()

0 comments on commit 07e38bd

Please sign in to comment.