Skip to content

Commit

Permalink
update protocol to support manual response mode
Browse files Browse the repository at this point in the history
  • Loading branch information
78 committed Nov 24, 2024
1 parent aa806f6 commit 472219d
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 100 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# CMakeLists in this exact order for cmake to work correctly
cmake_minimum_required(VERSION 3.16)

set(PROJECT_VER "0.9.1")
set(PROJECT_VER "0.9.2")

include($ENV{IDF_PATH}/tools/cmake/project.cmake)
project(xiaozhi)
109 changes: 76 additions & 33 deletions main/application.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,52 @@ void Application::ToggleChatState() {
Schedule([this]() {
if (chat_state_ == kChatStateIdle) {
SetChatState(kChatStateConnecting);
if (protocol_->OpenAudioChannel()) {
opus_encoder_.ResetState();
SetChatState(kChatStateListening);
} else {
if (!protocol_->OpenAudioChannel()) {
ESP_LOGE(TAG, "Failed to open audio channel");
SetChatState(kChatStateIdle);
return;
}

keep_listening_ = true;
protocol_->SendStartListening(kListeningModeAutoStop);
SetChatState(kChatStateListening);
} else if (chat_state_ == kChatStateSpeaking) {
AbortSpeaking();
AbortSpeaking(kAbortReasonNone);
} else if (chat_state_ == kChatStateListening) {
protocol_->CloseAudioChannel();
}
});
}

void Application::StartListening() {
Schedule([this]() {
keep_listening_ = false;
if (chat_state_ == kChatStateIdle) {
if (!protocol_->IsAudioChannelOpened()) {
SetChatState(kChatStateConnecting);
if (!protocol_->OpenAudioChannel()) {
SetChatState(kChatStateIdle);
ESP_LOGE(TAG, "Failed to open audio channel");
return;
}
}
protocol_->SendStartListening(kListeningModeManualStop);
SetChatState(kChatStateListening);
} else if (chat_state_ == kChatStateSpeaking) {
AbortSpeaking(kAbortReasonNone);
protocol_->SendStartListening(kListeningModeManualStop);
SetChatState(kChatStateListening);
}
});
}

void Application::StopListening() {
Schedule([this]() {
protocol_->SendStopListening();
SetChatState(kChatStateIdle);
});
}

void Application::Start() {
auto& board = Board::GetInstance();
board.Initialize();
Expand Down Expand Up @@ -248,26 +280,31 @@ void Application::Start() {
});
});

wake_word_detect_.OnWakeWordDetected([this]() {
Schedule([this]() {
wake_word_detect_.OnWakeWordDetected([this](const std::string& wake_word) {
Schedule([this, &wake_word]() {
if (chat_state_ == kChatStateIdle) {
SetChatState(kChatStateConnecting);
wake_word_detect_.EncodeWakeWordData();

if (protocol_->OpenAudioChannel()) {
std::string opus;
// Encode and send the wake word data to the server
while (wake_word_detect_.GetWakeWordOpus(opus)) {
protocol_->SendAudio(opus);
}
opus_encoder_.ResetState();
// Send a ready message to indicate the server that the wake word data is sent
SetChatState(kChatStateWakeWordDetected);
} else {
if (!protocol_->OpenAudioChannel()) {
ESP_LOGE(TAG, "Failed to open audio channel");
SetChatState(kChatStateIdle);
wake_word_detect_.StartDetection();
return;
}

std::string opus;
// Encode and send the wake word data to the server
while (wake_word_detect_.GetWakeWordOpus(opus)) {
protocol_->SendAudio(opus);
}
// Set the chat state to wake word detected
protocol_->SendWakeWordDetected(wake_word);
ESP_LOGI(TAG, "Wake word detected: %s", wake_word.c_str());
keep_listening_ = true;
SetChatState(kChatStateListening);
} else if (chat_state_ == kChatStateSpeaking) {
AbortSpeaking();
AbortSpeaking(kAbortReasonWakeWordDetected);
}

// Resume detection
Expand Down Expand Up @@ -313,15 +350,23 @@ void Application::Start() {
auto state = cJSON_GetObjectItem(root, "state");
if (strcmp(state->valuestring, "start") == 0) {
Schedule([this]() {
skip_to_end_ = false;
SetChatState(kChatStateSpeaking);
if (chat_state_ == kChatStateIdle || chat_state_ == kChatStateListening) {
skip_to_end_ = false;
opus_decoder_ctl(opus_decoder_, OPUS_RESET_STATE);
SetChatState(kChatStateSpeaking);
}
});
} else if (strcmp(state->valuestring, "stop") == 0) {
Schedule([this]() {
auto codec = Board::GetInstance().GetAudioCodec();
codec->WaitForOutputDone();
if (chat_state_ == kChatStateSpeaking) {
SetChatState(kChatStateListening);
if (keep_listening_) {
protocol_->SendStartListening(kListeningModeAutoStop);
SetChatState(kChatStateListening);
} else {
SetChatState(kChatStateIdle);
}
}
});
} else if (strcmp(state->valuestring, "sentence_start") == 0) {
Expand Down Expand Up @@ -375,9 +420,9 @@ void Application::MainLoop() {
}
}

void Application::AbortSpeaking() {
void Application::AbortSpeaking(AbortReason reason) {
ESP_LOGI(TAG, "Abort speaking");
protocol_->SendAbort();
protocol_->SendAbortSpeaking(reason);

skip_to_end_ = true;
auto codec = Board::GetInstance().GetAudioCodec();
Expand All @@ -391,20 +436,17 @@ void Application::SetChatState(ChatState state) {
"connecting",
"listening",
"speaking",
"wake_word_detected",
"upgrading",
"invalid_state"
};
if (chat_state_ == state) {
// No need to update the state
return;
}
chat_state_ = state;
ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]);

auto display = Board::GetInstance().GetDisplay();
auto builtin_led = Board::GetInstance().GetBuiltinLed();
switch (chat_state_) {
switch (state) {
case kChatStateUnknown:
case kChatStateIdle:
builtin_led->TurnOff();
Expand All @@ -424,6 +466,7 @@ void Application::SetChatState(ChatState state) {
builtin_led->TurnOn();
display->SetStatus("聆听中...");
display->SetEmotion("neutral");
opus_encoder_.ResetState();
#ifdef CONFIG_USE_AFE_SR
audio_processor_.Start();
#endif
Expand All @@ -436,17 +479,17 @@ void Application::SetChatState(ChatState state) {
audio_processor_.Stop();
#endif
break;
case kChatStateWakeWordDetected:
builtin_led->SetBlue();
builtin_led->TurnOn();
break;
case kChatStateUpgrading:
builtin_led->SetGreen();
builtin_led->StartContinuousBlink(100);
break;
default:
ESP_LOGE(TAG, "Invalid chat state: %d", chat_state_);
return;
}

protocol_->SendState(state_str[chat_state_]);
chat_state_ = state;
ESP_LOGI(TAG, "STATE: %s", state_str[chat_state_]);
}

void Application::AudioEncodeTask() {
Expand Down Expand Up @@ -474,7 +517,7 @@ void Application::AudioEncodeTask() {
audio_decode_queue_.pop_front();
lock.unlock();

if (skip_to_end_) {
if (skip_to_end_ || chat_state_ != kChatStateSpeaking) {
continue;
}

Expand Down
12 changes: 7 additions & 5 deletions main/application.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ enum ChatState {
kChatStateConnecting,
kChatStateListening,
kChatStateSpeaking,
kChatStateWakeWordDetected,
kChatStateUpgrading
};

Expand All @@ -41,17 +40,19 @@ class Application {
static Application instance;
return instance;
}
// 删除拷贝构造函数和赋值运算符
Application(const Application&) = delete;
Application& operator=(const Application&) = delete;

void Start();
ChatState GetChatState() const { return chat_state_; }
void Schedule(std::function<void()> callback);
void SetChatState(ChatState state);
void Alert(const std::string&& title, const std::string&& message);
void AbortSpeaking();
void AbortSpeaking(AbortReason reason);
void ToggleChatState();
// 删除拷贝构造函数和赋值运算符
Application(const Application&) = delete;
Application& operator=(const Application&) = delete;
void StartListening();
void StopListening();

private:
Application();
Expand All @@ -68,6 +69,7 @@ class Application {
Protocol* protocol_ = nullptr;
EventGroupHandle_t event_group_;
volatile ChatState chat_state_ = kChatStateUnknown;
bool keep_listening_ = false;
bool skip_to_end_ = false;

// Audio encode / decode
Expand Down
23 changes: 3 additions & 20 deletions main/protocols/mqtt_protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ bool MqttProtocol::StartMqttClient() {
} else if (strcmp(type->valuestring, "goodbye") == 0) {
auto session_id = cJSON_GetObjectItem(root, "session_id");
if (session_id == nullptr || session_id_ == session_id->valuestring) {
if (on_audio_channel_closed_ != nullptr) {
on_audio_channel_closed_();
}
Application::GetInstance().Schedule([this]() {
CloseAudioChannel();
});
}
} else if (on_incoming_json_ != nullptr) {
on_incoming_json_(root);
Expand Down Expand Up @@ -129,23 +129,6 @@ void MqttProtocol::SendAudio(const std::string& data) {
udp_->Send(encrypted);
}

void MqttProtocol::SendState(const std::string& state) {
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"state\",";
message += "\"state\":\"" + state + "\"";
message += "}";
SendText(message);
}

void MqttProtocol::SendAbort() {
std::string message = "{";
message += "\"session_id\":\"" + session_id_ + "\",";
message += "\"type\":\"abort\"";
message += "}";
SendText(message);
}

void MqttProtocol::CloseAudioChannel() {
{
std::lock_guard<std::mutex> lock(channel_mutex_);
Expand Down
6 changes: 2 additions & 4 deletions main/protocols/mqtt_protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ class MqttProtocol : public Protocol {
~MqttProtocol();

void SendAudio(const std::string& data) override;
void SendText(const std::string& text) override;
void SendState(const std::string& state) override;
void SendAbort() override;
bool OpenAudioChannel() override;
void CloseAudioChannel() override;
bool IsAudioChannelOpened() const override;
Expand All @@ -52,11 +49,12 @@ class MqttProtocol : public Protocol {
int udp_port_;
uint32_t local_sequence_;
uint32_t remote_sequence_;
std::string session_id_;

bool StartMqttClient();
void ParseServerHello(const cJSON* root);
std::string DecodeHexString(const std::string& hex_string);

void SendText(const std::string& text) override;
};


Expand Down
34 changes: 34 additions & 0 deletions main/protocols/protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,37 @@ void Protocol::OnAudioChannelClosed(std::function<void()> callback) {
void Protocol::OnNetworkError(std::function<void(const std::string& message)> callback) {
on_network_error_ = callback;
}

void Protocol::SendAbortSpeaking(AbortReason reason) {
std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"abort\"";
if (reason == kAbortReasonWakeWordDetected) {
message += ",\"reason\":\"wake_word_detected\"";
}
message += "}";
SendText(message);
}

void Protocol::SendWakeWordDetected(const std::string& wake_word) {
std::string json = "{\"session_id\":\"" + session_id_ +
"\",\"type\":\"listen\",\"state\":\"detect\",\"text\":\"" + wake_word + "\"}";
SendText(json);
}

void Protocol::SendStartListening(ListeningMode mode) {
std::string message = "{\"session_id\":\"" + session_id_ + "\"";
message += ",\"type\":\"listen\",\"state\":\"start\"";
if (mode == kListeningModeAlwaysOn) {
message += ",\"mode\":\"realtime\"";
} else if (mode == kListeningModeAutoStop) {
message += ",\"mode\":\"auto\"";
} else {
message += ",\"mode\":\"manual\"";
}
message += "}";
SendText(message);
}

void Protocol::SendStopListening() {
std::string message = "{\"session_id\":\"" + session_id_ + "\",\"type\":\"listen\",\"state\":\"stop\"}";
SendText(message);
}
22 changes: 18 additions & 4 deletions main/protocols/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ struct BinaryProtocol3 {
uint8_t payload[];
} __attribute__((packed));

enum AbortReason {
kAbortReasonNone,
kAbortReasonWakeWordDetected
};

enum ListeningMode {
kListeningModeAutoStop,
kListeningModeManualStop,
kListeningModeAlwaysOn // 需要 AEC 支持
};

class Protocol {
public:
Expand All @@ -27,13 +37,14 @@ class Protocol {
void OnAudioChannelClosed(std::function<void()> callback);
void OnNetworkError(std::function<void(const std::string& message)> callback);

virtual void SendAudio(const std::string& data) = 0;
virtual void SendText(const std::string& text) = 0;
virtual void SendState(const std::string& state) = 0;
virtual void SendAbort() = 0;
virtual bool OpenAudioChannel() = 0;
virtual void CloseAudioChannel() = 0;
virtual bool IsAudioChannelOpened() const = 0;
virtual void SendAudio(const std::string& data) = 0;
virtual void SendWakeWordDetected(const std::string& wake_word);
virtual void SendStartListening(ListeningMode mode);
virtual void SendStopListening();
virtual void SendAbortSpeaking(AbortReason reason);

protected:
std::function<void(const cJSON* root)> on_incoming_json_;
Expand All @@ -43,6 +54,9 @@ class Protocol {
std::function<void(const std::string& message)> on_network_error_;

int server_sample_rate_ = 16000;
std::string session_id_;

virtual void SendText(const std::string& text) = 0;
};

#endif // PROTOCOL_H
Expand Down
Loading

0 comments on commit 472219d

Please sign in to comment.