diff --git a/libraries/HTTPUpdateServer/src/HTTPUpdateServer.h b/libraries/HTTPUpdateServer/src/HTTPUpdateServer.h index bb32bc03fdb..65d8cbaa783 100644 --- a/libraries/HTTPUpdateServer/src/HTTPUpdateServer.h +++ b/libraries/HTTPUpdateServer/src/HTTPUpdateServer.h @@ -27,6 +27,7 @@ static const char serverIndex[] PROGMEM = )"; static const char successResponse[] PROGMEM = "Update Success! Rebooting..."; +static const char *csrfHeaders[2] = {"Origin", "Host"}; class HTTPUpdateServer { public: @@ -56,6 +57,9 @@ class HTTPUpdateServer { _username = username; _password = password; + // collect headers for CSRF verification + _server->collectHeaders(csrfHeaders, 2); + // handler for the /update form page _server->on(path.c_str(), HTTP_GET, [&]() { if (_username != emptyString && _password != emptyString && !_server->authenticate(_username.c_str(), _password.c_str())) { @@ -69,6 +73,10 @@ class HTTPUpdateServer { path.c_str(), HTTP_POST, [&]() { if (!_authenticated) { + if (_username == emptyString || _password == emptyString) { + _server->send(200, F("text/html"), String(F("Update error: Wrong origin received!"))); + return; + } return _server->requestAuthentication(); } if (Update.hasError()) { @@ -100,6 +108,17 @@ class HTTPUpdateServer { return; } + String origin = _server->header(String(csrfHeaders[0])); + String host = _server->header(String(csrfHeaders[1])); + String expectedOrigin = String("http://") + host; + if (origin != expectedOrigin) { + if (_serial_output) { + Serial.printf("Wrong origin received! Expected: %s, Received: %s\n", expectedOrigin.c_str(), origin.c_str()); + } + _authenticated = false; + return; + } + if (_serial_output) { Serial.printf("Update: %s\n", upload.filename.c_str()); } diff --git a/libraries/Update/examples/OTAWebUpdater/OTAWebUpdater.ino b/libraries/Update/examples/OTAWebUpdater/OTAWebUpdater.ino index 7059bef4496..39d6cbce4af 100644 --- a/libraries/Update/examples/OTAWebUpdater/OTAWebUpdater.ino +++ b/libraries/Update/examples/OTAWebUpdater/OTAWebUpdater.ino @@ -8,10 +8,17 @@ #define SSID_FORMAT "ESP32-%06lX" // 12 chars total //#define PASSWORD "test123456" // generate if remarked +// Set the username and password for firmware upload +const char *authUser = "........"; +const char *authPass = "........"; + WebServer server(80); Ticker tkSecond; uint8_t otaDone = 0; +const char *csrfHeaders[2] = {"Origin", "Host"}; +static bool authenticated = false; + const char *alphanum = "0123456789!@#$%^&*abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; String generatePass(uint8_t str_len) { String buff; @@ -38,6 +45,9 @@ void apMode() { } void handleUpdateEnd() { + if (!authenticated) { + return server.requestAuthentication(); + } server.sendHeader("Connection", "close"); if (Update.hasError()) { server.send(502, "text/plain", Update.errorString()); @@ -45,6 +55,7 @@ void handleUpdateEnd() { server.sendHeader("Refresh", "10"); server.sendHeader("Location", "/"); server.send(307); + delay(500); ESP.restart(); } } @@ -56,18 +67,34 @@ void handleUpdate() { } HTTPUpload &upload = server.upload(); if (upload.status == UPLOAD_FILE_START) { + authenticated = server.authenticate(authUser, authPass); + if (!authenticated) { + Serial.println("Authentication fail!"); + otaDone = 0; + return; + } + String origin = server.header(String(csrfHeaders[0])); + String host = server.header(String(csrfHeaders[1])); + String expectedOrigin = String("http://") + host; + if (origin != expectedOrigin) { + Serial.printf("Wrong origin received! Expected: %s, Received: %s\n", expectedOrigin.c_str(), origin.c_str()); + authenticated = false; + otaDone = 0; + return; + } + Serial.printf("Receiving Update: %s, Size: %d\n", upload.filename.c_str(), fsize); if (!Update.begin(fsize)) { otaDone = 0; Update.printError(Serial); } - } else if (upload.status == UPLOAD_FILE_WRITE) { + } else if (authenticated && upload.status == UPLOAD_FILE_WRITE) { if (Update.write(upload.buf, upload.currentSize) != upload.currentSize) { Update.printError(Serial); } else { otaDone = 100 * Update.progress() / Update.size(); } - } else if (upload.status == UPLOAD_FILE_END) { + } else if (authenticated && upload.status == UPLOAD_FILE_END) { if (Update.end(true)) { Serial.printf("Update Success: %u bytes\nRebooting...\n", upload.totalSize); } else { @@ -78,6 +105,7 @@ void handleUpdate() { } void webServerInit() { + server.collectHeaders(csrfHeaders, 2); server.on( "/update", HTTP_POST, []() { @@ -92,6 +120,9 @@ void webServerInit() { server.send_P(200, "image/x-icon", favicon_ico_gz, favicon_ico_gz_len); }); server.onNotFound([]() { + if (!server.authenticate(authUser, authPass)) { + return server.requestAuthentication(); + } server.send(200, "text/html", indexHtml); }); server.begin(); diff --git a/libraries/WebServer/examples/WebUpdate/WebUpdate.ino b/libraries/WebServer/examples/WebUpdate/WebUpdate.ino index 10ddb5e7b64..9e45de7d985 100644 --- a/libraries/WebServer/examples/WebUpdate/WebUpdate.ino +++ b/libraries/WebServer/examples/WebUpdate/WebUpdate.ino @@ -12,10 +12,17 @@ const char *host = "esp32-webupdate"; const char *ssid = "........"; const char *password = "........"; +// Set the username and password for firmware upload +const char *authUser = "........"; +const char *authPass = "........"; + WebServer server(80); const char *serverIndex = "
"; +const char *csrfHeaders[2] = {"Origin", "Host"}; +static bool authenticated = false; + void setup(void) { Serial.begin(115200); Serial.println(); @@ -24,37 +31,63 @@ void setup(void) { WiFi.begin(ssid, password); if (WiFi.waitForConnectResult() == WL_CONNECTED) { MDNS.begin(host); + server.collectHeaders(csrfHeaders, 2); server.on("/", HTTP_GET, []() { + if (!server.authenticate(authUser, authPass)) { + return server.requestAuthentication(); + } server.sendHeader("Connection", "close"); server.send(200, "text/html", serverIndex); }); server.on( "/update", HTTP_POST, []() { + if (!authenticated) { + return server.requestAuthentication(); + } server.sendHeader("Connection", "close"); - server.send(200, "text/plain", (Update.hasError()) ? "FAIL" : "OK"); - ESP.restart(); + if (Update.hasError()) { + server.send(200, "text/plain", "FAIL"); + } else { + server.send(200, "text/plain", "Success! Rebooting..."); + delay(500); + ESP.restart(); + } }, []() { HTTPUpload &upload = server.upload(); if (upload.status == UPLOAD_FILE_START) { Serial.setDebugOutput(true); + authenticated = server.authenticate(authUser, authPass); + if (!authenticated) { + Serial.println("Authentication fail!"); + return; + } + String origin = server.header(String(csrfHeaders[0])); + String host = server.header(String(csrfHeaders[1])); + String expectedOrigin = String("http://") + host; + if (origin != expectedOrigin) { + Serial.printf("Wrong origin received! Expected: %s, Received: %s\n", expectedOrigin.c_str(), origin.c_str()); + authenticated = false; + return; + } + Serial.printf("Update: %s\n", upload.filename.c_str()); if (!Update.begin()) { //start with max available size Update.printError(Serial); } - } else if (upload.status == UPLOAD_FILE_WRITE) { + } else if (authenticated && upload.status == UPLOAD_FILE_WRITE) { if (Update.write(upload.buf, upload.currentSize) != upload.currentSize) { Update.printError(Serial); } - } else if (upload.status == UPLOAD_FILE_END) { + } else if (authenticated && upload.status == UPLOAD_FILE_END) { if (Update.end(true)) { //true to set the size to the current progress Serial.printf("Update Success: %u\nRebooting...\n", upload.totalSize); } else { Update.printError(Serial); } Serial.setDebugOutput(false); - } else { + } else if (authenticated) { Serial.printf("Update Failed Unexpectedly (likely broken connection): status=%d\n", upload.status); } }