Skip to content

Commit

Permalink
Fix API socket issues (esphome#2288)
Browse files Browse the repository at this point in the history
* Fix API socket issues

* Fix compile error against beta

* Format
  • Loading branch information
OttoWinter authored Sep 13, 2021
1 parent 40c474c commit ed7983a
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 77 deletions.
74 changes: 49 additions & 25 deletions esphome/components/api/api_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,14 @@ void APIConnection::start() {

APIError err = helper_->init();
if (err != APIError::OK) {
ESP_LOGW(TAG, "Helper init failed: %d errno=%d", (int) err, errno);
remove_ = true;
on_fatal_error();
ESP_LOGW(TAG, "%s: Helper init failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno);
return;
}
client_info_ = helper_->getpeername();
helper_->set_log_info(client_info_);
}

void APIConnection::force_disconnect_client() {
this->helper_->close();
this->remove_ = true;
}

void APIConnection::loop() {
if (this->remove_)
return;
Expand All @@ -57,9 +52,11 @@ void APIConnection::loop() {
// when network is disconnected force disconnect immediately
// don't wait for timeout
this->on_fatal_error();
ESP_LOGW(TAG, "%s: Network unavailable, disconnecting", client_info_.c_str());
return;
}
if (this->next_close_) {
// requested a disconnect
this->helper_->close();
this->remove_ = true;
return;
Expand All @@ -68,7 +65,7 @@ void APIConnection::loop() {
APIError err = helper_->loop();
if (err != APIError::OK) {
on_fatal_error();
ESP_LOGW(TAG, "%s: Socket operation failed: %d", client_info_.c_str(), (int) err);
ESP_LOGW(TAG, "%s: Socket operation failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno);
return;
}
ReadPacketBuffer buffer;
Expand All @@ -77,7 +74,11 @@ void APIConnection::loop() {
// pass
} else if (err != APIError::OK) {
on_fatal_error();
ESP_LOGW(TAG, "%s: Reading failed: %d", client_info_.c_str(), (int) err);
if (err == APIError::SOCKET_READ_FAILED && errno == ECONNRESET) {
ESP_LOGW(TAG, "%s: Connection reset", client_info_.c_str());
} else {
ESP_LOGW(TAG, "%s: Reading failed: %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno);
}
return;
} else {
this->last_traffic_ = millis();
Expand All @@ -95,8 +96,8 @@ void APIConnection::loop() {
if (this->sent_ping_) {
// Disconnect if not responded within 2.5*keepalive
if (now - this->last_traffic_ > (keepalive * 5) / 2) {
this->force_disconnect_client();
ESP_LOGW(TAG, "'%s' didn't respond to ping request in time. Disconnecting...", this->client_info_.c_str());
on_fatal_error();
ESP_LOGW(TAG, "%s didn't respond to ping request in time. Disconnecting...", this->client_info_.c_str());
}
} else if (now - this->last_traffic_ > keepalive) {
this->sent_ping_ = true;
Expand Down Expand Up @@ -124,12 +125,40 @@ void APIConnection::loop() {
}
}
#endif

if (state_subs_at_ != -1) {
const auto &subs = this->parent_->get_state_subs();
if (state_subs_at_ >= subs.size()) {
state_subs_at_ = -1;
} else {
auto &it = subs[state_subs_at_];
SubscribeHomeAssistantStateResponse resp;
resp.entity_id = it.entity_id;
resp.attribute = it.attribute.value();
if (this->send_subscribe_home_assistant_state_response(resp)) {
state_subs_at_++;
}
}
}
}

std::string get_default_unique_id(const std::string &component_type, Nameable *nameable) {
return App.get_name() + component_type + nameable->get_object_id();
}

DisconnectResponse APIConnection::disconnect(const DisconnectRequest &msg) {
// remote initiated disconnect_client
// don't close yet, we still need to send the disconnect response
// close will happen on next loop
ESP_LOGD(TAG, "%s requested disconnected", client_info_.c_str());
this->next_close_ = true;
DisconnectResponse resp;
return resp;
}
void APIConnection::on_disconnect_response(const DisconnectResponse &value) {
// pass
}

#ifdef USE_BINARY_SENSOR
bool APIConnection::send_binary_sensor_state(binary_sensor::BinarySensor *binary_sensor, bool state) {
if (!this->state_subscription_)
Expand Down Expand Up @@ -703,7 +732,7 @@ ConnectResponse APIConnection::connect(const ConnectRequest &msg) {
// bool invalid_password = 1;
resp.invalid_password = !correct;
if (correct) {
ESP_LOGD(TAG, "Client '%s' connected successfully!", this->client_info_.c_str());
ESP_LOGD(TAG, "%s: Connected successfully", this->client_info_.c_str());
this->connection_state_ = ConnectionState::AUTHENTICATED;

#ifdef USE_HOMEASSISTANT_TIME
Expand Down Expand Up @@ -749,15 +778,7 @@ void APIConnection::execute_service(const ExecuteServiceRequest &msg) {
}
}
void APIConnection::subscribe_home_assistant_states(const SubscribeHomeAssistantStatesRequest &msg) {
for (auto &it : this->parent_->get_state_subs()) {
SubscribeHomeAssistantStateResponse resp;
resp.entity_id = it.entity_id;
resp.attribute = it.attribute.value();
if (!this->send_subscribe_home_assistant_state_response(resp)) {
this->on_fatal_error();
return;
}
}
state_subs_at_ = 0;
}
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) {
if (this->remove_)
Expand All @@ -770,22 +791,25 @@ bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint32_t message_type)
return false;
if (err != APIError::OK) {
on_fatal_error();
ESP_LOGW(TAG, "%s: Packet write failed %d errno=%d", client_info_.c_str(), (int) err, errno);
if (err == APIError::SOCKET_WRITE_FAILED && errno == ECONNRESET) {
ESP_LOGW(TAG, "%s: Connection reset", client_info_.c_str());
} else {
ESP_LOGW(TAG, "%s: Packet write failed %s errno=%d", client_info_.c_str(), api_error_to_str(err), errno);
}
return false;
}
this->last_traffic_ = millis();
return true;
}
void APIConnection::on_unauthenticated_access() {
this->on_fatal_error();
ESP_LOGD(TAG, "'%s' tried to access without authentication.", this->client_info_.c_str());
ESP_LOGD(TAG, "%s: tried to access without authentication.", this->client_info_.c_str());
}
void APIConnection::on_no_setup_connection() {
this->on_fatal_error();
ESP_LOGD(TAG, "'%s' tried to access without full connection.", this->client_info_.c_str());
ESP_LOGD(TAG, "%s: tried to access without full connection.", this->client_info_.c_str());
}
void APIConnection::on_fatal_error() {
ESP_LOGV(TAG, "Error: Disconnecting %s", this->client_info_.c_str());
this->helper_->close();
this->remove_ = true;
}
Expand Down
16 changes: 3 additions & 13 deletions esphome/components/api/api_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class APIConnection : public APIServerConnection {
virtual ~APIConnection() = default;

void start();
void force_disconnect_client();
void loop();

bool send_list_info_done() {
Expand Down Expand Up @@ -88,10 +87,7 @@ class APIConnection : public APIServerConnection {
}
#endif

void on_disconnect_response(const DisconnectResponse &value) override {
this->helper_->close();
this->remove_ = true;
}
void on_disconnect_response(const DisconnectResponse &value) override;
void on_ping_response(const PingResponse &value) override {
// we initiated ping
this->sent_ping_ = false;
Expand All @@ -102,14 +98,7 @@ class APIConnection : public APIServerConnection {
#endif
HelloResponse hello(const HelloRequest &msg) override;
ConnectResponse connect(const ConnectRequest &msg) override;
DisconnectResponse disconnect(const DisconnectRequest &msg) override {
// remote initiated disconnect_client
// don't close yet, we still need to send the disconnect response
// close will happen on next loop
this->next_close_ = true;
DisconnectResponse resp;
return resp;
}
DisconnectResponse disconnect(const DisconnectRequest &msg) override;
PingResponse ping(const PingRequest &msg) override { return {}; }
DeviceInfoResponse device_info(const DeviceInfoRequest &msg) override;
void list_entities(const ListEntitiesRequest &msg) override { this->list_entities_iterator_.begin(); }
Expand Down Expand Up @@ -177,6 +166,7 @@ class APIConnection : public APIServerConnection {
APIServer *parent_;
InitialStateIterator initial_state_iterator_;
ListEntitiesIterator list_entities_iterator_;
int state_subs_at_ = -1;
};

} // namespace api
Expand Down
70 changes: 51 additions & 19 deletions esphome/components/api/api_frame_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,54 @@ bool is_would_block(ssize_t ret) {
return ret == 0;
}

const char *api_error_to_str(APIError err) {
// not using switch to ensure compiler doesn't try to build a big table out of it
if (err == APIError::OK) {
return "OK";
} else if (err == APIError::WOULD_BLOCK) {
return "WOULD_BLOCK";
} else if (err == APIError::BAD_HANDSHAKE_PACKET_LEN) {
return "BAD_HANDSHAKE_PACKET_LEN";
} else if (err == APIError::BAD_INDICATOR) {
return "BAD_INDICATOR";
} else if (err == APIError::BAD_DATA_PACKET) {
return "BAD_DATA_PACKET";
} else if (err == APIError::TCP_NODELAY_FAILED) {
return "TCP_NODELAY_FAILED";
} else if (err == APIError::TCP_NONBLOCKING_FAILED) {
return "TCP_NONBLOCKING_FAILED";
} else if (err == APIError::CLOSE_FAILED) {
return "CLOSE_FAILED";
} else if (err == APIError::SHUTDOWN_FAILED) {
return "SHUTDOWN_FAILED";
} else if (err == APIError::BAD_STATE) {
return "BAD_STATE";
} else if (err == APIError::BAD_ARG) {
return "BAD_ARG";
} else if (err == APIError::SOCKET_READ_FAILED) {
return "SOCKET_READ_FAILED";
} else if (err == APIError::SOCKET_WRITE_FAILED) {
return "SOCKET_WRITE_FAILED";
} else if (err == APIError::HANDSHAKESTATE_READ_FAILED) {
return "HANDSHAKESTATE_READ_FAILED";
} else if (err == APIError::HANDSHAKESTATE_WRITE_FAILED) {
return "HANDSHAKESTATE_WRITE_FAILED";
} else if (err == APIError::HANDSHAKESTATE_BAD_STATE) {
return "HANDSHAKESTATE_BAD_STATE";
} else if (err == APIError::CIPHERSTATE_DECRYPT_FAILED) {
return "CIPHERSTATE_DECRYPT_FAILED";
} else if (err == APIError::CIPHERSTATE_ENCRYPT_FAILED) {
return "CIPHERSTATE_ENCRYPT_FAILED";
} else if (err == APIError::OUT_OF_MEMORY) {
return "OUT_OF_MEMORY";
} else if (err == APIError::HANDSHAKESTATE_SETUP_FAILED) {
return "HANDSHAKESTATE_SETUP_FAILED";
} else if (err == APIError::HANDSHAKESTATE_SPLIT_FAILED) {
return "HANDSHAKESTATE_SPLIT_FAILED";
}
return "UNKNOWN";
}

#define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__)

#ifdef USE_API_NOISE
Expand Down Expand Up @@ -808,14 +856,12 @@ APIError APIPlaintextFrameHelper::try_send_tx_buf_() {
// try send from tx_buf
while (state_ != State::CLOSED && !tx_buf_.empty()) {
ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size());
if (sent == -1) {
if (errno == EWOULDBLOCK || errno == EAGAIN)
break;
if (is_would_block(sent)) {
break;
} else if (sent == -1) {
state_ = State::FAILED;
HELPER_LOG("Socket write failed with errno %d", errno);
return APIError::SOCKET_WRITE_FAILED;
} else if (sent == 0) {
break;
}
// TODO: inefficient if multiple packets in txbuf
// replace with deque of buffers
Expand Down Expand Up @@ -869,20 +915,6 @@ APIError APIPlaintextFrameHelper::write_raw_(const uint8_t *data, size_t len) {
// fully sent
return APIError::OK;
}
APIError APIPlaintextFrameHelper::write_frame_(const uint8_t *data, size_t len) {
APIError aerr;

uint8_t header[3];
header[0] = 0x01; // indicator
header[1] = (uint8_t)(len >> 8);
header[2] = (uint8_t) len;

aerr = write_raw_(header, 3);
if (aerr != APIError::OK)
return aerr;
aerr = write_raw_(data, len);
return aerr;
}

APIError APIPlaintextFrameHelper::close() {
state_ = State::CLOSED;
Expand Down
3 changes: 2 additions & 1 deletion esphome/components/api/api_frame_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ enum class APIError : int {
HANDSHAKESTATE_SPLIT_FAILED = 1020,
};

const char *api_error_to_str(APIError err);

class APIFrameHelper {
public:
virtual APIError init() = 0;
Expand Down Expand Up @@ -150,7 +152,6 @@ class APIPlaintextFrameHelper : public APIFrameHelper {

APIError try_read_frame_(ParsedFrame *frame);
APIError try_send_tx_buf_();
APIError write_frame_(const uint8_t *data, size_t len);
APIError write_raw_(const uint8_t *data, size_t len);

std::unique_ptr<socket::Socket> socket_;
Expand Down
2 changes: 1 addition & 1 deletion esphome/components/api/api_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ void APIServer::loop() {
[](const std::unique_ptr<APIConnection> &conn) { return !conn->remove_; });
// print disconnection messages
for (auto it = new_end; it != this->clients_.end(); ++it) {
ESP_LOGD(TAG, "Disconnecting %s", (*it)->client_info_.c_str());
ESP_LOGV(TAG, "Removing connection to %s", (*it)->client_info_.c_str());
}
// resize vector
this->clients_.erase(new_end, this->clients_.end());
Expand Down
Loading

0 comments on commit ed7983a

Please sign in to comment.