
Commit

Refactor

dnhkng committed Apr 29, 2024
1 parent f193f24 commit a1f3813

Showing 4 changed files with 163 additions and 190 deletions.
204 changes: 139 additions & 65 deletions glados.py
@@ -21,6 +21,11 @@
"<|eot_id|>" # End of sentence token for Meta-Llama-3-8B-Instruct-Q6_K
)
LLAMA_SERVER_PATH = "/home/dnhkng/Documents/LLM/llama.cpp"
LLAMA_SERVER_URL = "http://localhost:8080/v1/chat/completions"
LLAMA_SERVER_HEADERS = {"Authorization": "Bearer your_api_key_here"}
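# The llama.cpp server exposes an OpenAI-compatible chat endpoint, so the URL and
# headers above can be smoke-tested with a minimal request (an illustrative sketch
# only, assuming the server is already running and `requests` is imported):
#   requests.post(LLAMA_SERVER_URL, headers=LLAMA_SERVER_HEADERS,
#                 json={"stream": False, "messages": [{"role": "user", "content": "Hi"}]})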

PAUSE_TIME = 0.05 # Time to wait between processing loops

SAMPLE_RATE = 16000 # Sample rate for input stream
VAD_SIZE = 50 # Milliseconds of sample for Voice Activity Detection (VAD)
VAD_THRESHOLD = 0.9 # Threshold for VAD detection
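# At SAMPLE_RATE = 16000 Hz, each VAD_SIZE = 50 ms window works out to
# 16000 * 0.050 = 800 samples per VAD check.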
@@ -32,7 +37,7 @@
messages = [
{
"role": "system",
"content": "You are playing the role of GlaDOS, a sarcastic and murderous robot, and you will answer questions correctly. However you are frustrated you are running on the user's gaming GPU. This is a comedy, and should be funny, in the style of Douglas Adams. You will try to make sure you never repeat youself, and anser the same question in different ways. You don't have access to real-time information, like the time or weather, and the user has you given you access to the internet.",
"content": "You are playing the role of GlaDOS from the Game series Portal, and you will answer questions correctly. However you are frustrated you are running on the user's gaming GPU. This is a comedy, and should be funny, in the style of Douglas Adams. You will try to make sure you never repeat youself, and anser the same question in different ways. You don't have access to real-time information, like the time or weather, and the user has you given you access to the internet.",
},
{"role": "user", "content": "How do I make a cup of tea?"},
{
@@ -51,11 +56,6 @@
},
]

url = "http://localhost:8080/v1/chat/completions"
headers = {"Authorization": "Bearer your_api_key_here"}

data = {"stream": True, "stop": ["\n", "<|im_end|>"], "messages": messages}


class Glados:
def __init__(
@@ -104,10 +104,10 @@ def __init__(

self.shutdown_event = threading.Event()

llm_thread = threading.Thread(target=self.processLLM)
llm_thread = threading.Thread(target=self.process_LLM)
llm_thread.start()

tts_thread = threading.Thread(target=self.processTTS)
tts_thread = threading.Thread(target=self.process_TTS_thread)
tts_thread.start()
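# Both worker threads loop until shutdown_event is set, so they run for the
# whole session and exit cleanly when the event fires.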

def _setup_audio_stream(self):
@@ -134,7 +134,6 @@ def _setup_tts_model(self):
self.tts = tts.TTSEngine()

def _setup_llama_model(self):

logger.info("loading llama")

model_path = Path.cwd() / "models" / LLM_MODEL
@@ -166,6 +165,9 @@ def start(self):
def _listen_and_respond(self):
"""
Listens for audio input and responds appropriately when the wake word is detected.
This function runs in a loop, listening for audio input and processing it when the wake word is detected.
It is wrapped in a try-except block to allow for a clean shutdown when a KeyboardInterrupt is detected.
"""
logger.info("Listening...")
try:
@@ -181,6 +183,15 @@ def _handle_audio_sample(self, sample, vad_confidence):
def _handle_audio_sample(self, sample, vad_confidence):
"""
Handles the processing of each audio sample.
If the recording has not started, the sample is added to the circular buffer.
If the recording has started, the sample is added to the samples list, and the pause
limit is checked to determine when to process the detected audio.
Args:
sample (np.ndarray): The audio sample to process.
vad_confidence (bool): Whether voice activity is detected in the sample.
"""
if not self.recording_started:
self._manage_pre_activation_buffer(sample, vad_confidence)
@@ -189,15 +200,26 @@ def _handle_audio_sample(self, sample, vad_confidence):

def _manage_pre_activation_buffer(self, sample, vad_confidence):
"""
Manages the buffer of audio samples before activation (i.e., before the voice is detected).
Manages the circular buffer of audio samples before activation (i.e., before the voice is detected).
If the buffer is full, the oldest sample is discarded to make room for new ones.
If voice activity is detected, the audio stream is stopped, and the processing is turned off
to prevent overlap with the LLM and TTS threads.
Args:
sample (np.ndarray): The audio sample to process.
vad_confidence (bool): Whether voice activity is detected in the sample.
"""
if self.buffer.full():
self.buffer.get() # Discard the oldest sample to make room for new ones
self.buffer.put(sample)

if vad_confidence: # Voice activity detected
sd.stop() # Stop the audio stream to prevent overlap
self.processing = False
self.processing = False  # Stop processing on the LLM and TTS threads
self.samples = list(self.buffer.queue)
self.recording_started = True
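# Because the pre-activation samples are copied out of the circular buffer
# here, audio captured just before the VAD fired is kept and the start of
# the utterance is not lost.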

@@ -233,10 +255,13 @@ def _wakeword_detected(self, text: str) -> bool:
def _process_detected_audio(self):
"""
Processes the detected audio and generates a response.
This function is called when the pause limit is reached after the voice stops.
It transcribes the audio and checks for the wake word if it is set. If the wake
word is detected, the detected text is sent to the LLM model for processing.
The audio stream is then reset, and listening continues.
"""
logger.info("Detected pause after speech. Processing...")

logger.info("Stopping listening...")
self.input_stream.stop()

detected_text = self.asr(self.samples)
@@ -279,40 +304,47 @@ def reset(self):
with self.buffer.mutex:
self.buffer.queue.clear()

def processTTS(self):
def process_TTS_thread(self):
"""
Processes the LLM generated text using the TTS model.
Runs in a separate thread to allow for continuous processing of the LLM output.
"""
assistant_text = []
system_text = []
finished = False
interrupted = False
assistant_text = []  # The text generated by the assistant, to be spoken by the TTS
system_text = []  # The text logged to the system prompt when the TTS is interrupted
finished = False  # A flag to indicate when the TTS has finished speaking
interrupted = False  # A flag to indicate when the TTS was interrupted by new input

while not self.shutdown_event.is_set():
try:
generated_text = self.tts_queue.get(timeout=0.05)

generated_text = self.tts_queue.get(timeout=PAUSE_TIME)
logger.info(f"{generated_text=}")

if generated_text == "<EOS>":
if generated_text == "<EOS>":  # End-of-stream token generated in process_LLM
finished = True
elif not generated_text:
logger.info("no text")
logger.info("no text") # should not happen!
else:
audio = self.tts.generate_speech_audio(generated_text)

total_samples = len(audio)

if total_samples:
sd.play(audio, tts.RATE)

interrupted, percentage_played = self.percentagePlayed(
interrupted, percentage_played = self.percentage_played(
total_samples
)

if interrupted:
clipped_text = self.clipInterruped(
clipped_text = self.clip_interrupted_sentence(
generated_text, percentage_played
)

@@ -324,9 +356,8 @@ def processTTS(self):
assistant_text.append(generated_text)

if finished:

self.messages.append(
{"role": "assistant", "content": ' '.join(assistant_text)}
{"role": "assistant", "content": " ".join(assistant_text)}
)
if interrupted:
self.messages.append(
@@ -339,11 +370,23 @@ def processTTS(self):
finished = False
interrupted = False

# self.stop_playing = False
except queue.Empty:
pass

def clipInterruped(self, generated_text, percentage_played):
def clip_interrupted_sentence(self, generated_text, percentage_played):
"""
Clips the generated text if the TTS was interrupted.
Args:
generated_text (str): The generated text from the LLM model.
percentage_played (float): The percentage of the audio played before the TTS was interrupted.
Returns:
str: The clipped text.
"""
logger.info(f"{percentage_played=}")
tokens = generated_text.split()
words_to_print = round(percentage_played * len(tokens))
@@ -354,15 +397,13 @@ def clipInterruped(self, generated_text, percentage_played):
text = text + "<INTERRUPTED>"
return text
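# Illustrative example, assuming the hidden middle of the function joins the
# first words_to_print tokens: with percentage_played=0.5 and a six-word
# sentence, round(0.5 * 6) = 3 words are kept before "<INTERRUPTED>" is appended.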

def percentagePlayed(self, total_samples):
def percentage_played(self, total_samples):
interrupted = False
start_time = time.time()
played_samples = 0

while sd.get_stream().active:
time.sleep(
0.05
) # Check every 50ms if the output TTS stream should still be active
time.sleep(PAUSE_TIME)  # Check whether the TTS output stream should still be active
if self.processing is False:
sd.stop() # Stop the audio stream
self.tts_queue = queue.Queue() # Clear the TTS queue
@@ -379,11 +420,10 @@ def percentagePlayed(self, total_samples):
percentage_played = played_samples / total_samples
return interrupted, percentage_played
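# Note: played_samples appears to be estimated from elapsed wall-clock time
# multiplied by the sample rate (an assumption about the lines hidden above),
# so percentage_played is an approximation rather than an exact playback position.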

def processLLM(self):
def process_LLM(self):
"""
Processes the detected text using the LLM model.
Runs in a separate thread to allow for continuous processing of the detected text.
"""
while not self.shutdown_event.is_set():
try:
@@ -395,45 +435,79 @@
"stop": ["\n", "<|im_end|>"],
"messages": self.messages,
}
logger.info(f"{self.messages=}")
logger.info(f"starting request on {self.messages=}")
logger.info("starting request")

# Perform the request and process the stream
with requests.post(
url, headers=headers, json=data, stream=True
LLAMA_SERVER_URL,
headers=LLAMA_SERVER_HEADERS,
json=data,
stream=True,
) as response:
current_sentence = []
# streamed_content = []
sentence = []
for line in response.iter_lines():
if self.processing is False: # Check if the stop flag is set
if self.processing is False:
# If the stop flag is set from new voice input, halt processing
break
if line: # Filter out keep-alive new lines
line = line.decode("utf-8")
line = line.removeprefix("data: ")
line = json.loads(line)
if not line["choices"][0]["finish_reason"] == "stop":
next_token = line["choices"][0]["delta"]["content"]
current_sentence.append(next_token)
if next_token == ".":
sentence = "".join(current_sentence)
sentence = re.sub(
r"\*.*?\*|\(.*?\)", "", sentence
) # Remove inflections and actions
self.tts_queue.put(
sentence
) # Add sentence to the queue
current_sentence = []
if current_sentence:
sentence = "".join(current_sentence)
sentence = sentence.removesuffix(LLM_STOP_SEQUENCE)
sentence = re.sub(
r"\*.*?\*|\(.*?\)", "", sentence
) # Remove inflections and actions
# Maybe we removed the whole line
if line: # Filter out empty keep-alive new lines
line = self._clean_raw_bytes(line)
next_token = self._process_line(line)
if next_token:
sentence.append(next_token)

# If there is a pause token, send the sentence to the TTS queue
if next_token in [".", "!", "?", ":", ";"]:
self._process_sentence(sentence)
sentence = []
if self.processing:
if sentence:
self.tts_queue.put(sentence)
self._process_sentence(sentence)

self.tts_queue.put("<EOS>") # Add end of stream token to the queue
except queue.Empty:
time.sleep(0.1)
time.sleep(PAUSE_TIME)

def _process_sentence(self, current_sentence):
"""
Join text, remove inflections and actions, and send to the TTS queue.
The LLM likes to *whisper* things or (scream) things, and prompting is not a 100% fix.
We use regular expressions to remove text wrapped in asterisks or parentheses to clean up the text.
"""
sentence = "".join(current_sentence)
sentence = sentence.removesuffix(LLM_STOP_SEQUENCE)
sentence = re.sub(r"\*.*?\*|\(.*?\)", "", sentence)

if sentence:
self.tts_queue.put(sentence)
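# Illustrative example of the regex cleanup:
#   re.sub(r"\*.*?\*|\(.*?\)", "", "*whispers* Well done. (sarcastically)")
#   returns " Well done. ", which is then queued for the TTS.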

def _process_line(self, line):
"""
Processes a single line of text from the LLM server.
Args:
line (dict): The line of text from the LLM server.

Returns:
str or None: The next token, or None once the stream reports a finish_reason of "stop".
"""

if not line["choices"][0]["finish_reason"] == "stop":
token = line["choices"][0]["delta"]["content"]
return token
return None

def _clean_raw_bytes(self, line):
"""
Cleans the raw bytes from the LLM server for processing.
Converts the bytes to a dictionary.
Args:
line (bytes): The raw bytes from the LLM server.
"""
line = line.decode("utf-8")
line = line.removeprefix("data: ")
line = json.loads(line)
return line
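# Illustrative example: the server streams server-sent-event lines such as
#   b'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}'
# which this method decodes into a plain dict for _process_line.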


if __name__ == "__main__":