Skip to content

Commit

Permalink
新的opus封装以及优化const会导致的内存拷贝
Browse files Browse the repository at this point in the history
  • Loading branch information
78 committed Dec 3, 2024
1 parent 9c1f8a1 commit bcfd120
Show file tree
Hide file tree
Showing 15 changed files with 102 additions and 102 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# CMakeLists in this exact order for cmake to work correctly
cmake_minimum_required(VERSION 3.16)

set(PROJECT_VER "0.9.5")
set(PROJECT_VER "0.9.6")

include($ENV{IDF_PATH}/tools/cmake/project.cmake)
project(xiaozhi)
97 changes: 47 additions & 50 deletions main/application.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ extern const char p3_err_pin_end[] asm("_binary_err_pin_p3_end");
extern const char p3_err_wificonfig_start[] asm("_binary_err_wificonfig_p3_start");
extern const char p3_err_wificonfig_end[] asm("_binary_err_wificonfig_p3_end");

static const char* const STATE_STRINGS[] = {
"unknown",
"idle",
"connecting",
"listening",
"speaking",
"upgrading",
"invalid_state"
};

Application::Application() : background_task_(4096 * 8) {
event_group_ = xEventGroupCreate();
Expand All @@ -30,13 +39,6 @@ Application::Application() : background_task_(4096 * 8) {
}

Application::~Application() {
if (protocol_ != nullptr) {
delete protocol_;
}
if (opus_decoder_ != nullptr) {
opus_decoder_destroy(opus_decoder_);
}

vEventGroupDelete(event_group_);
}

Expand Down Expand Up @@ -83,7 +85,7 @@ void Application::CheckNewVersion() {
}
}

void Application::Alert(const std::string&& title, const std::string&& message) {
void Application::Alert(const std::string& title, const std::string& message) {
ESP_LOGW(TAG, "Alert: %s, %s", title.c_str(), message.c_str());
auto display = Board::GetInstance().GetDisplay();
display->ShowNotification(message);
Expand All @@ -105,7 +107,7 @@ void Application::PlayLocalFile(const char* data, size_t size) {
p += sizeof(BinaryProtocol3);

auto payload_size = ntohs(p3->payload_size);
std::string opus;
std::vector<uint8_t> opus;
opus.resize(payload_size);
memcpy(opus.data(), p3->payload, payload_size);
p += payload_size;
Expand All @@ -117,10 +119,15 @@ void Application::PlayLocalFile(const char* data, size_t size) {

void Application::ToggleChatState() {
Schedule([this]() {
if (!protocol_) {
ESP_LOGE(TAG, "Protocol not initialized");
return;
}

if (chat_state_ == kChatStateIdle) {
SetChatState(kChatStateConnecting);
if (!protocol_->OpenAudioChannel()) {
ESP_LOGE(TAG, "Failed to open audio channel");
Alert("Error", "Failed to open audio channel");
SetChatState(kChatStateIdle);
return;
}
Expand All @@ -138,13 +145,18 @@ void Application::ToggleChatState() {

void Application::StartListening() {
Schedule([this]() {
if (!protocol_) {
ESP_LOGE(TAG, "Protocol not initialized");
return;
}

keep_listening_ = false;
if (chat_state_ == kChatStateIdle) {
if (!protocol_->IsAudioChannelOpened()) {
SetChatState(kChatStateConnecting);
if (!protocol_->OpenAudioChannel()) {
SetChatState(kChatStateIdle);
ESP_LOGE(TAG, "Failed to open audio channel");
Alert("Error", "Failed to open audio channel");
return;
}
}
Expand Down Expand Up @@ -183,8 +195,8 @@ void Application::Start() {
/* Setup the audio codec */
auto codec = board.GetAudioCodec();
opus_decode_sample_rate_ = codec->output_sample_rate();
opus_decoder_ = opus_decoder_create(opus_decode_sample_rate_, 1, NULL);
opus_encoder_.Configure(16000, 1, OPUS_FRAME_DURATION_MS);
opus_decoder_ = std::make_unique<OpusDecoderWrapper>(opus_decode_sample_rate_, 1);
opus_encoder_ = std::make_unique<OpusEncoderWrapper>(16000, 1, OPUS_FRAME_DURATION_MS);
if (codec->input_sample_rate() != 16000) {
input_resampler_.Configure(codec->input_sample_rate(), 16000);
reference_resampler_.Configure(codec->input_sample_rate(), 16000);
Expand Down Expand Up @@ -221,9 +233,9 @@ void Application::Start() {
#if CONFIG_IDF_TARGET_ESP32S3
audio_processor_.Initialize(codec->input_channels(), codec->input_reference());
audio_processor_.OnOutput([this](std::vector<int16_t>&& data) {
background_task_.Schedule([this, data = std::move(data)]() {
opus_encoder_.Encode(data, [this](const uint8_t* opus, size_t opus_size) {
Schedule([this, opus = std::string(reinterpret_cast<const char*>(opus), opus_size)]() {
background_task_.Schedule([this, data = std::move(data)]() mutable {
opus_encoder_->Encode(std::move(data), [this](std::vector<uint8_t>&& opus) {
Schedule([this, opus = std::move(opus)]() {
protocol_->SendAudio(opus);
});
});
Expand Down Expand Up @@ -258,7 +270,7 @@ void Application::Start() {
return;
}

std::string opus;
std::vector<uint8_t> opus;
// Encode and send the wake word data to the server
while (wake_word_detect_.GetWakeWordOpus(opus)) {
protocol_->SendAudio(opus);
Expand All @@ -282,14 +294,14 @@ void Application::Start() {
// Initialize the protocol
display->SetStatus("初始化协议");
#ifdef CONFIG_CONNECTION_TYPE_WEBSOCKET
protocol_ = new WebsocketProtocol();
protocol_ = std::make_unique<WebsocketProtocol>();
#else
protocol_ = new MqttProtocol();
protocol_ = std::make_unique<MqttProtocol>();
#endif
protocol_->OnNetworkError([this](const std::string& message) {
Alert("Error", std::move(message));
});
protocol_->OnIncomingAudio([this](const std::string& data) {
protocol_->OnIncomingAudio([this](std::vector<uint8_t>&& data) {
std::lock_guard<std::mutex> lock(mutex_);
if (chat_state_ == kChatStateSpeaking) {
audio_decode_queue_.emplace_back(std::move(data));
Expand Down Expand Up @@ -363,9 +375,8 @@ void Application::Start() {
}

void Application::Schedule(std::function<void()> callback) {
mutex_.lock();
main_tasks_.push_back(callback);
mutex_.unlock();
std::lock_guard<std::mutex> lock(mutex_);
main_tasks_.push_back(std::move(callback));
xEventGroupSetBits(event_group_, SCHEDULE_EVENT);
}

Expand Down Expand Up @@ -397,7 +408,7 @@ void Application::MainLoop() {

void Application::ResetDecoder() {
std::lock_guard<std::mutex> lock(mutex_);
opus_decoder_ctl(opus_decoder_, OPUS_RESET_STATE);
opus_decoder_->ResetState();
audio_decode_queue_.clear();
last_output_time_ = std::chrono::steady_clock::now();
Board::GetInstance().GetAudioCodec()->EnableOutput(true);
Expand Down Expand Up @@ -430,24 +441,21 @@ void Application::OutputAudio() {
audio_decode_queue_.pop_front();
lock.unlock();

background_task_.Schedule([this, codec, opus = std::move(opus)]() {
background_task_.Schedule([this, codec, opus = std::move(opus)]() mutable {
if (aborted_) {
return;
}
int frame_size = opus_decode_sample_rate_ * OPUS_FRAME_DURATION_MS / 1000;
std::vector<int16_t> pcm(frame_size);

int ret = opus_decode(opus_decoder_, (const unsigned char*)opus.data(), opus.size(), pcm.data(), frame_size, 0);
if (ret < 0) {
ESP_LOGE(TAG, "Failed to decode audio, error code: %d", ret);
std::vector<int16_t> pcm;
if (!opus_decoder_->Decode(std::move(opus), pcm)) {
return;
}

// Resample if the sample rate is different
if (opus_decode_sample_rate_ != codec->output_sample_rate()) {
int target_size = output_resampler_.GetOutputSamples(frame_size);
int target_size = output_resampler_.GetOutputSamples(pcm.size());
std::vector<int16_t> resampled(target_size);
output_resampler_.Process(pcm.data(), frame_size, resampled.data());
output_resampler_.Process(pcm.data(), pcm.size(), resampled.data());
pcm = std::move(resampled);
}

Expand Down Expand Up @@ -495,9 +503,9 @@ void Application::InputAudio() {
}
#else
if (chat_state_ == kChatStateListening) {
background_task_.Schedule([this, data = std::move(data)]() {
opus_encoder_.Encode(data, [this](const uint8_t* opus, size_t opus_size) {
Schedule([this, opus = std::string(reinterpret_cast<const char*>(opus), opus_size)]() {
background_task_.Schedule([this, data = std::move(data)]() mutable {
opus_encoder_->Encode(std::move(data), [this](std::vector<uint8_t>&& opus) {
Schedule([this, opus = std::move(opus)]() {
protocol_->SendAudio(opus);
});
});
Expand All @@ -513,22 +521,12 @@ void Application::AbortSpeaking(AbortReason reason) {
}

void Application::SetChatState(ChatState state) {
const char* state_str[] = {
"unknown",
"idle",
"connecting",
"listening",
"speaking",
"upgrading",
"invalid_state"
};
if (chat_state_ == state) {
// No need to update the state
return;
}

chat_state_ = state;
ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]);
ESP_LOGI(TAG, "STATE: %s", STATE_STRINGS[chat_state_]);
// The state is changed, wait for all background tasks to finish
background_task_.WaitForCompletion();

Expand All @@ -555,7 +553,7 @@ void Application::SetChatState(ChatState state) {
display->SetStatus("聆听中...");
display->SetEmotion("neutral");
ResetDecoder();
opus_encoder_.ResetState();
opus_encoder_->ResetState();
#if CONFIG_IDF_TARGET_ESP32S3
audio_processor_.Start();
#endif
Expand Down Expand Up @@ -584,9 +582,8 @@ void Application::SetDecodeSampleRate(int sample_rate) {
return;
}

opus_decoder_destroy(opus_decoder_);
opus_decode_sample_rate_ = sample_rate;
opus_decoder_ = opus_decoder_create(opus_decode_sample_rate_, 1, NULL);
opus_decoder_ = std::make_unique<OpusDecoderWrapper>(opus_decode_sample_rate_, 1);

auto codec = Board::GetInstance().GetAudioCodec();
if (opus_decode_sample_rate_ != codec->output_sample_rate()) {
Expand Down
11 changes: 6 additions & 5 deletions main/application.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <condition_variable>

#include "opus_encoder.h"
#include "opus_decoder.h"
#include "opus_resampler.h"

#include "protocol.h"
Expand Down Expand Up @@ -52,7 +53,7 @@ class Application {
ChatState GetChatState() const { return chat_state_; }
void Schedule(std::function<void()> callback);
void SetChatState(ChatState state);
void Alert(const std::string&& title, const std::string&& message);
void Alert(const std::string& title, const std::string& message);
void AbortSpeaking(AbortReason reason);
void ToggleChatState();
void StartListening();
Expand All @@ -69,7 +70,7 @@ class Application {
Ota ota_;
std::mutex mutex_;
std::list<std::function<void()>> main_tasks_;
Protocol* protocol_ = nullptr;
std::unique_ptr<Protocol> protocol_;
EventGroupHandle_t event_group_;
volatile ChatState chat_state_ = kChatStateUnknown;
bool keep_listening_ = false;
Expand All @@ -78,10 +79,10 @@ class Application {
// Audio encode / decode
BackgroundTask background_task_;
std::chrono::steady_clock::time_point last_output_time_;
std::list<std::string> audio_decode_queue_;
std::list<std::vector<uint8_t>> audio_decode_queue_;

OpusEncoder opus_encoder_;
OpusDecoder* opus_decoder_ = nullptr;
std::unique_ptr<OpusEncoderWrapper> opus_encoder_;
std::unique_ptr<OpusDecoderWrapper> opus_decoder_;

int opus_decode_sample_rate_ = -1;
OpusResampler input_resampler_;
Expand Down
2 changes: 1 addition & 1 deletion main/audio_processing/audio_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ AudioProcessor::~AudioProcessor() {
vEventGroupDelete(event_group_);
}

void AudioProcessor::Input(std::vector<int16_t>& data) {
void AudioProcessor::Input(const std::vector<int16_t>& data) {
input_buffer_.insert(input_buffer_.end(), data.begin(), data.end());

auto chunk_size = esp_afe_vc_v1.get_feed_chunksize(afe_communication_data_) * channels_;
Expand Down
2 changes: 1 addition & 1 deletion main/audio_processing/audio_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class AudioProcessor {
~AudioProcessor();

void Initialize(int channels, bool reference);
void Input(std::vector<int16_t>& data);
void Input(const std::vector<int16_t>& data);
void Start();
void Stop();
bool IsRunning();
Expand Down
44 changes: 21 additions & 23 deletions main/audio_processing/wake_word_detect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ bool WakeWordDetect::IsDetectionRunning() {
return xEventGroupGetBits(event_group_) & DETECTION_RUNNING_EVENT;
}

void WakeWordDetect::Feed(std::vector<int16_t>& data) {
void WakeWordDetect::Feed(const std::vector<int16_t>& data) {
input_buffer_.insert(input_buffer_.end(), data.begin(), data.end());

auto chunk_size = esp_afe_sr_v1.get_feed_chunksize(afe_detection_data_) * channels_;
Expand Down Expand Up @@ -163,8 +163,7 @@ void WakeWordDetect::AudioDetectionTask() {

void WakeWordDetect::StoreWakeWordData(uint16_t* data, size_t samples) {
// store audio data to wake_word_pcm_
std::vector<int16_t> pcm(data, data + samples);
wake_word_pcm_.emplace_back(std::move(pcm));
wake_word_pcm_.emplace_back(std::vector<int16_t>(data, data + samples));
// keep about 2 seconds of data, detect duration is 32ms (sample_rate == 16000, chunksize == 512)
while (wake_word_pcm_.size() > 2000 / 32) {
wake_word_pcm_.pop_front();
Expand All @@ -178,34 +177,33 @@ void WakeWordDetect::EncodeWakeWordData() {
}
wake_word_encode_task_ = xTaskCreateStatic([](void* arg) {
auto this_ = (WakeWordDetect*)arg;
auto start_time = esp_timer_get_time();
// encode detect packets
OpusEncoder* encoder = new OpusEncoder();
encoder->Configure(16000, 1, 60);
encoder->SetComplexity(0);

for (auto& pcm: this_->wake_word_pcm_) {
encoder->Encode(pcm, [this_](const uint8_t* opus, size_t opus_size) {
std::lock_guard<std::mutex> lock(this_->wake_word_mutex_);
this_->wake_word_opus_.emplace_back(std::string(reinterpret_cast<const char*>(opus), opus_size));
this_->wake_word_cv_.notify_all();
});
}
this_->wake_word_pcm_.clear();

auto end_time = esp_timer_get_time();
ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms", this_->wake_word_opus_.size(), (end_time - start_time) / 1000);
{
auto start_time = esp_timer_get_time();
auto encoder = std::make_unique<OpusEncoderWrapper>(16000, 1, OPUS_FRAME_DURATION_MS);
encoder->SetComplexity(0); // 0 is the fastest

for (auto& pcm: this_->wake_word_pcm_) {
encoder->Encode(std::move(pcm), [this_](std::vector<uint8_t>&& opus) {
std::lock_guard<std::mutex> lock(this_->wake_word_mutex_);
this_->wake_word_opus_.emplace_back(std::move(opus));
this_->wake_word_cv_.notify_all();
});
}
this_->wake_word_pcm_.clear();

auto end_time = esp_timer_get_time();
ESP_LOGI(TAG, "Encode wake word opus %zu packets in %lld ms",
this_->wake_word_opus_.size(), (end_time - start_time) / 1000);

std::lock_guard<std::mutex> lock(this_->wake_word_mutex_);
this_->wake_word_opus_.push_back("");
this_->wake_word_opus_.push_back(std::vector<uint8_t>());
this_->wake_word_cv_.notify_all();
}
delete encoder;
vTaskDelete(NULL);
}, "encode_detect_packets", 4096 * 8, this, 1, wake_word_encode_task_stack_, &wake_word_encode_task_buffer_);
}

bool WakeWordDetect::GetWakeWordOpus(std::string& opus) {
bool WakeWordDetect::GetWakeWordOpus(std::vector<uint8_t>& opus) {
std::unique_lock<std::mutex> lock(wake_word_mutex_);
wake_word_cv_.wait(lock, [this]() {
return !wake_word_opus_.empty();
Expand Down
Loading

0 comments on commit bcfd120

Please sign in to comment.