Skip to content

Commit

Permalink
improve tts selection logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 24, 2024
1 parent 0f2621d commit 7c8b0ad
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 101 deletions.
63 changes: 0 additions & 63 deletions agixt/extensions/huggingface.py

This file was deleted.

107 changes: 69 additions & 38 deletions agixt/extensions/voice_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,41 +53,21 @@ def __init__(self, WHISPER_MODEL="base.en", **kwargs):
self.user = kwargs["user"] if "user" in kwargs else DEFAULT_USER
self.tts_command = "Speak with TTS with Streamlabs Text to Speech"
if "USE_STREAMLABS_TTS" in kwargs:
if isinstance(kwargs["USE_STREAMLABS_TTS"], bool):
if kwargs["USE_STREAMLABS_TTS"]:
self.tts_command = "Speak with TTS with Streamlabs Text to Speech"
else:
if kwargs["USE_STREAMLABS_TTS"].lower() == "true":
self.tts_command = "Speak with TTS with Streamlabs Text to Speech"
if str(kwargs["USE_STREAMLABS_TTS"]).lower() == "true":
self.tts_command = "Speak with TTS with Streamlabs Text to Speech"
if "USE_GTTS" in kwargs:
if isinstance(kwargs["USE_GTTS"], bool):
if kwargs["USE_GTTS"]:
self.tts_command = "Speak with GTTS"
else:
if kwargs["USE_GTTS"].lower() == "true":
self.tts_command = "Speak with GTTS"
if "USE_HUGGINGFACE_TTS" in kwargs:
if isinstance(kwargs["USE_HUGGINGFACE_TTS"], bool):
if kwargs["USE_HUGGINGFACE_TTS"] and "HUGGINGFACE_API_KEY" in kwargs:
if kwargs["ELEVENLABS_API_KEY"] != "":
self.tts_command = "Read Audio with Huggingface"
else:
if (
kwargs["USE_HUGGINGFACE_TTS"].lower() == "true"
and "HUGGINGFACE_API_KEY" in kwargs
):
if kwargs["HUGGINGFACE_API_KEY"] != "":
self.tts_command = "Read Audio with Huggingface"
if str(kwargs["USE_GTTS"]).lower() == "true":
self.tts_command = "Speak with GTTS"
if "ELEVENLABS_API_KEY" in kwargs:
if kwargs["ELEVENLABS_API_KEY"] != "":
self.tts_command = "Speak with TTS Using Elevenlabs"
if "USE_ALLTALK_TTS" in kwargs:
if kwargs["USE_ALLTALK_TTS"].lower() == "true":
if kwargs["USE_ALLTALK_TTS"]:
self.tts_command = "Speak with TTS with Alltalk Text to Speech"
if str(kwargs["USE_ALLTALK_TTS"]).lower() == "true":
self.tts_command = "Speak with TTS with Alltalk Text to Speech"

self.commands = {
"Chat with Voice": self.chat_with_voice,
"Command with Voice": self.command_with_voice,
"Transcribe WAV Audio": self.transcribe_wav_audio,
"Transcribe M4A Audio": self.transcribe_m4a_audio,
"Transcribe WEBM Audio": self.transcribe_webm_audio,
Expand Down Expand Up @@ -211,21 +191,41 @@ async def transcribe_webm_audio(
os.remove(os.path.join(os.getcwd(), "WORKSPACE", filename))
return user_input

async def get_wav_audio(
self,
base64_audio,
audio_format="m4a",
):
filename = f"{uuid.uuid4().hex}.wav"
if audio_format.lower() == "webm":
user_audio = await self.convert_webm_to_wav(
base64_audio=base64_audio, filename=filename
)
elif audio_format.lower() == "m4a":
user_audio = await self.convert_m4a_to_wav(
base64_audio=base64_audio, filename=filename
)
else:
user_audio = base64_audio
return user_audio

async def chat_with_voice(
self,
base64_audio,
context_results=10,
tts=False,
inject_memories_from_collection_number=0,
audio_format="m4a",
prompt_name="Custom Input",
prompt_args={
"context_results": 6,
"inject_memories_from_collection_number": 0,
},
):
# Convert from M4A to WAV
filename = f"{uuid.uuid4().hex}.wav"
user_audio = await self.convert_m4a_to_wav(
base64_audio=base64_audio, filename=filename
user_audio = await self.get_wav_audio(
base64_audio=base64_audio, audio_format=audio_format, filename=filename
)
# Transcribe the audio to text.
user_input = await self.transcribe_audio_from_file(filename=filename)
prompt_args["user_input"] = user_input
user_message = f"{user_input}\n#GENERATED_AUDIO:{user_audio}"
log_interaction(
agent_name=self.agent_name,
Expand All @@ -239,11 +239,42 @@ async def chat_with_voice(
text_response = self.ApiClient.prompt_agent(
agent_name=self.agent_name,
prompt_name=prompt_name,
prompt_args={
"user_input": user_input,
"context_results": context_results,
"inject_memories_from_collection_number": inject_memories_from_collection_number,
},
prompt_args=prompt_args,
)
logging.info(f"[Whisper]: Text Response from LLM: {text_response}")
return self.text_to_speech(text=text_response)

async def command_with_voice(
self,
base64_audio,
audio_format="m4a",
audio_variable="data_to_correlate_with_input",
command_name="Store information in my long term memory",
command_args={"input": "Voice transcription from user"},
tts=False,
):
filename = f"{uuid.uuid4().hex}.wav"
user_audio = await self.get_wav_audio(
base64_audio=base64_audio, audio_format=audio_format, filename=filename
)
# Transcribe the audio to text.
user_input = await self.transcribe_audio_from_file(filename=filename)
command_args[audio_variable] = user_input
user_message = f"{user_input}\n#GENERATED_AUDIO:{user_audio}"
log_interaction(
agent_name=self.agent_name,
conversation_name=self.conversation_name,
role="USER",
message=user_message,
user=self.user,
)
logging.info(f"[Whisper]: Transcribed User Input: {user_input}")
# Send the transcribed text to the agent.
text_response = self.ApiClient.execute_command(
agent_name=self.agent_name,
command_name=command_name,
command_args=command_args,
conversation_name="AGiXT Terminal",
)
logging.info(f"[Whisper]: Text Response from LLM: {text_response}")
if str(tts).lower() == "true":
Expand Down

0 comments on commit 7c8b0ad

Please sign in to comment.