Skip to content

Commit

Permalink
Working permessage-deflate, minus some params
Browse files Browse the repository at this point in the history
  • Loading branch information
jcheng5 committed May 22, 2023
1 parent 6ab9394 commit e7f4c16
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 63 deletions.
52 changes: 31 additions & 21 deletions src/deflate.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "deflate.h"
#include <stdexcept>
#include <algorithm>
#include "utils.h"

namespace deflator {
Expand All @@ -8,32 +9,37 @@ typedef int(*flate_func)(z_stream* strm, int flush);

// Generic function for driving I/O loop for inflate/deflate
int flate(flate_func func, z_stream* strm, const char* data, size_t data_len, std::vector<char>& output) {
int error;

const size_t CHUNK_SIZE = 256;
unsigned char temp[CHUNK_SIZE];
memset(temp, 0, CHUNK_SIZE);

strm->next_in = reinterpret_cast<unsigned char*>(const_cast<char*>(data));
strm->avail_in = data_len;

size_t orig_output_size = output.size();

while (strm->avail_in) {
// Ensure enough room is allocated on the output to receive 1024 more bytes
const size_t CHUNK_SIZE = 1024;
output.resize(output.size() + CHUNK_SIZE);

strm->next_out = reinterpret_cast<unsigned char*>(&output[output.size() - CHUNK_SIZE]);
do {
strm->next_out = temp;
strm->avail_out = CHUNK_SIZE;
error = func(strm, Z_NO_FLUSH);
if (error != Z_OK && error != Z_BUF_ERROR) {
debug_log("flate failed", LOG_INFO);
return error;
}
// This might not copy any data; sometimes deflate() just populates buffers
output.insert(output.end(), (char*)temp, (char*)(temp + CHUNK_SIZE - strm->avail_out));
} while (strm->avail_out == 0 || strm->avail_in > 0);

int error = func(strm, Z_SYNC_FLUSH);
if (error != Z_OK && error != Z_STREAM_END) {
// std::cerr << strm->msg << "\n";
do {
strm->next_out = temp;
strm->avail_out = CHUNK_SIZE;
error = func(strm, Z_SYNC_FLUSH);
if (error != Z_OK && error != Z_BUF_ERROR) {
debug_log("flate failed", LOG_INFO);
return error;
}
}
size_t bytes_written = strm->total_out;
size_t bytes_allocated = output.size() - orig_output_size;
size_t bytes_to_erase = bytes_allocated - bytes_written;
if (bytes_to_erase > 0) {
output.erase(output.end() - bytes_to_erase, output.end());
}
output.insert(output.end(), (char*)temp, (char*)(temp + CHUNK_SIZE - strm->avail_out));
} while (strm->avail_out == 0);

return Z_OK;
}
Expand All @@ -46,7 +52,9 @@ Deflator::Deflator() {
Deflator::~Deflator() {
if (_state == DeflatorStateReady) {
int error = deflateEnd(&_stream);
if (error != Z_OK) {
if (error == Z_STREAM_ERROR) {
// deflateEnd can return other errors, but they're nothing to worry about.
// https://stackoverflow.com/a/19816633/139922
debug_log("deflateEnd failed", LOG_WARN);
}
}
Expand Down Expand Up @@ -90,7 +98,9 @@ Inflator::Inflator() {
Inflator::~Inflator() {
if (_state == DeflatorStateReady) {
int error = inflateEnd(&_stream);
if (error != Z_OK) {
if (error == Z_STREAM_ERROR) {
// inflateEnd can return other errors, but they're nothing to worry about.
// https://stackoverflow.com/a/19816633/139922
debug_log("inflateEnd failed", LOG_WARN);
}
}
Expand Down Expand Up @@ -121,7 +131,7 @@ int Inflator::init(DeflateMode mode, int windowBits) {

int Inflator::inflate(const char* data, size_t data_len, std::vector<char>& output) {
if (_state != DeflatorStateReady) {
throw std::runtime_error("Inflator.init() must be called before deflate()");
throw std::runtime_error("Inflator.init() must be called before inflate()");
}
return flate(::inflate, &_stream, data, data_len, output);
}
Expand Down
10 changes: 10 additions & 0 deletions src/websockets-base.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,19 @@
#include "constants.h"

struct WebSocketConnectionContext {
WebSocketConnectionContext() :
permessageDeflate(false),
clientMaxWindowBits(-1),
serverMaxWindowBits(-1),
clientNoContextTakeover(false),
serverNoContextTakeover(false) {
}

bool permessageDeflate;
int clientMaxWindowBits;
int serverMaxWindowBits;
bool clientNoContextTakeover;
bool serverNoContextTakeover;
};

class WebSocketProto {
Expand Down
42 changes: 4 additions & 38 deletions src/websockets-ietf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,15 @@
#include "sha1/sha1.h"
#include "base64/base64.hpp"

struct ExtensionInfo {
std::string extension;
std::map<std::string, std::string> params;
};

ExtensionInfo parseExtensionInfo(const std::string str) {
std::vector<std::string> parts = split(str, ";");
std::string extension = parts[0];
std::map<std::string, std::string> params;
for (size_t i = 1; i < parts.size(); i++) {
std::vector<std::string> param = split(parts[i], "=");
params[param[0]] = param.size() > 1 ? param[1] : "";
}
ExtensionInfo result;
result.extension = extension;
result.params = params;
return result;
}

std::vector<std::string> splitExtensionsHeader(const std::string& header) {
return split(header, ",");
}
#include "wse-permessage-deflate.h"

bool WebSocketProto_IETF::canHandle(const RequestHeaders& requestHeaders,
const char* pData, size_t len) const {

return requestHeaders.find("upgrade") != requestHeaders.end() &&
strcasecmp(requestHeaders.at("upgrade").c_str(), "websocket") == 0 &&
requestHeaders.find("sec-websocket-key") != requestHeaders.end();
requestHeaders.find("sec-websocket-key") != requestHeaders.end() &&
permessage_deflate::isValid(requestHeaders);
}

void WebSocketProto_IETF::handshake(const std::string& url,
Expand Down Expand Up @@ -65,21 +45,7 @@ void WebSocketProto_IETF::handshake(const std::string& url,
pResponseHeaders->push_back(
std::pair<std::string, std::string>("Sec-WebSocket-Accept", response));

auto swe = requestHeaders.find("sec-websocket-extensions");
if (swe != requestHeaders.end()) {
auto extensions = split(swe->second, ",");
std::vector<ExtensionInfo> extInfos;
std::transform(extensions.begin(), extensions.end(),
std::back_inserter(extInfos),
parseExtensionInfo);

for (auto pos = extInfos.begin(); pos != extInfos.end(); pos++) {
if (trim(pos->extension) == "permessage-deflate") {
pResponseHeaders->push_back(std::make_pair("Sec-WebSocket-Extensions", "permessage-deflate"));
pContext->permessageDeflate = true;
}
}
}
permessage_deflate::handshake(requestHeaders, pResponseHeaders, pContext);
}

bool WebSocketProto_IETF::isFin(uint8_t firstBit) const {
Expand Down
11 changes: 8 additions & 3 deletions src/websockets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,11 @@ void WebSocketConnection::handshake(const std::string& url,
if (_context.permessageDeflate) {
int error;
// TODO: Handle errors
error = _inflator.init(deflator::DeflateModeRaw);
error = _inflator.init(deflator::DeflateModeRaw, _context.clientMaxWindowBits);
if (error != Z_OK) {
debug_log("Failed to init inflator", LOG_ERROR);
}
error = _deflator.init(deflator::DeflateModeRaw, Z_DEFAULT_COMPRESSION, 9);
error = _deflator.init(deflator::DeflateModeRaw, Z_DEFAULT_COMPRESSION, _context.serverMaxWindowBits);
if (error != Z_OK) {
debug_log("Failed to init deflator", LOG_ERROR);
}
Expand Down Expand Up @@ -415,7 +415,12 @@ void WebSocketConnection::onFrameComplete() {
_payload.push_back(0xFF);
_payload.push_back(0xFF);
std::vector<char> inflated(0);
_inflator.inflate(safe_vec_addr(_payload), _payload.size(), inflated);
int error = _inflator.inflate(safe_vec_addr(_payload), _payload.size(), inflated);
if (error != Z_OK) {
// TODO: Handle error
std::cerr << "Inflate failed with error " << error << "\n";
}

_pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(inflated), inflated.size());
} else {
_pCallbacks->onWSMessage(_header.opcode == Binary, safe_vec_addr(_payload), _payload.size());
Expand Down
4 changes: 3 additions & 1 deletion src/websockets.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ class WebSocketConnection : WSParserCallbacks, NoCopy {
: _pLoop(pLoop),
_connState(WS_OPEN),
_pCallbacks(callbacks),
_pParser(NULL) {
_pParser(NULL),
_deflator(),
_inflator() {
ASSERT_BACKGROUND_THREAD()
debug_log("WebSocketConnection::WebSocketConnection", LOG_DEBUG);

Expand Down
155 changes: 155 additions & 0 deletions src/wse-permessage-deflate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#include "wse-permessage-deflate.h"
#include <algorithm>
#include <map>
#include <string>
#include <vector>

struct ExtensionInfo {
std::string extension;
std::map<std::string, std::string> params;
};

ExtensionInfo parseExtensionInfo(const std::string str) {
std::vector<std::string> parts = split(str, ";");

std::transform(parts.cbegin(), parts.cend(), parts.begin(), trim);

std::string extension = parts[0];
std::map<std::string, std::string> params;
for (size_t i = 1; i < parts.size(); i++) {
std::vector<std::string> param = split(parts[i], "=");
params[trim(param[0])] = param.size() > 1 ? trim(param[1]) : "";
}
ExtensionInfo result;
result.extension = extension;
result.params = params;
return result;
}

bool parseWindowBits(const ExtensionInfo& extInfo, const std::string& param, bool* present, int* value) {
*present = false;
*value = 0;

auto match = extInfo.params.find(param);
if (match == extInfo.params.end()) {
// Not present--this is valid, so return true
return true;
}

*present = true;

if (match->second.size() == 0) {
// Param is present, but value is not
return true;
}

if (match->second.size() > 2) {
// Too many digits
return false;
}
*value = atoi(match->second.c_str());
if (*value < 8 || *value > 15) {
// Out of range
return false;
}
return true;
}

bool parsePermessageDeflate(const ExtensionInfo& extInfo, WebSocketConnectionContext* pContext) {
if (extInfo.extension != "permessage-deflate") {
return false;
}

pContext->permessageDeflate = true;
if (extInfo.params.find("server_no_context_takeover") != extInfo.params.end()) {
pContext->serverNoContextTakeover = true;
}
if (extInfo.params.find("client_no_context_takeover") != extInfo.params.end()) {
pContext->clientNoContextTakeover = true;
}

bool hasServerMaxWindowBits;
bool hasClientMaxWindowBits;

if (!parseWindowBits(extInfo, "server_max_window_bits", &hasServerMaxWindowBits, &pContext->serverMaxWindowBits)) {
return false;
}
if (hasServerMaxWindowBits && pContext->serverMaxWindowBits == 0) {
// If server_max_window_bits is present, the value is required
return false;
}

if (!parseWindowBits(extInfo, "client_max_window_bits", &hasClientMaxWindowBits, &pContext->clientMaxWindowBits)) {
return false;
}

// Set defaults
if (pContext->serverMaxWindowBits <= 0) {
pContext->serverMaxWindowBits = 15;
}
if (pContext->clientMaxWindowBits <= 0) {
pContext->clientMaxWindowBits = 15;
}

return true;
}

bool handle(const RequestHeaders& requestHeaders,
ResponseHeaders* pResponseHeaders,
WebSocketConnectionContext* pContext) {

auto swe = requestHeaders.find("sec-websocket-extensions");
if (swe != requestHeaders.end()) {
auto extensions = split(swe->second, ",");
std::vector<ExtensionInfo> extInfos;
std::transform(extensions.begin(), extensions.end(),
std::back_inserter(extInfos),
parseExtensionInfo);

for (auto &extInfo : extInfos) {
if (trim(extInfo.extension) == "permessage-deflate") {
if (!parsePermessageDeflate(extInfo, pContext)) {
return false;
}
}
}
}

if (pResponseHeaders && pContext->permessageDeflate) {
std::string params;
if (pContext->clientNoContextTakeover) {
params.append("; client_no_context_takeover");
}
if (pContext->serverNoContextTakeover) {
params.append("; server_no_context_takeover");
}
if (pContext->serverMaxWindowBits != 0) {
params.append("; server_max_window_bits=");
params.append(std::to_string(pContext->serverMaxWindowBits));
}
if (pContext->clientMaxWindowBits != 0) {
params.append("; client_max_window_bits=");
params.append(std::to_string(pContext->clientMaxWindowBits));
}
std::string exts = "permessage-deflate" + params;
pResponseHeaders->push_back(std::make_pair("Sec-WebSocket-Extensions", exts));
}

return true;
}

namespace permessage_deflate {

bool isValid(const RequestHeaders& requestHeaders) {
WebSocketConnectionContext context;
return handle(requestHeaders, NULL, &context);
}

void handshake(const RequestHeaders& requestHeaders,
ResponseHeaders* pResponseHeaders,
WebSocketConnectionContext* pContext) {

handle(requestHeaders, pResponseHeaders, pContext);
}

} // namespace permessage_deflate
17 changes: 17 additions & 0 deletions src/wse-permessage-deflate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef WSEPERMESSAGEDEFLATE_H
#define WSEPERMESSAGEDEFLATE_H

#include "constants.h"
#include "websockets-base.h"

namespace permessage_deflate {

bool isValid(const RequestHeaders& requestHeaders);

void handshake(const RequestHeaders& requestHeaders,
ResponseHeaders* pResponseHeaders,
WebSocketConnectionContext* pContext);

} // namespace permessage_deflate

#endif

0 comments on commit e7f4c16

Please sign in to comment.