Skip to content

Commit

Permalink
Add optional user defined header writer (yhirose#1683)
Browse files Browse the repository at this point in the history
* Add optional user defined header writer

* Fix errors and add test
  • Loading branch information
PabloMK7 authored Oct 1, 2023
1 parent c029597 commit a609330
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
36 changes: 34 additions & 2 deletions httplib.h
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@ class RegexMatcher : public MatcherBase {
std::regex regex_;
};

ssize_t write_headers(Stream &strm, const Headers &headers);

} // namespace detail

class Server {
Expand Down Expand Up @@ -800,6 +802,8 @@ class Server {
Server &set_socket_options(SocketOptions socket_options);

Server &set_default_headers(Headers headers);
Server &
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);

Server &set_keep_alive_max_count(size_t count);
Server &set_keep_alive_timeout(time_t sec);
Expand Down Expand Up @@ -934,6 +938,8 @@ class Server {
SocketOptions socket_options_ = default_socket_options;

Headers default_headers_;
std::function<ssize_t(Stream &, Headers &)> header_writer_ =
detail::write_headers;
};

enum class Error {
Expand Down Expand Up @@ -1164,6 +1170,9 @@ class ClientImpl {

void set_default_headers(Headers headers);

void
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);

void set_address_family(int family);
void set_tcp_nodelay(bool on);
void set_socket_options(SocketOptions socket_options);
Expand Down Expand Up @@ -1273,6 +1282,10 @@ class ClientImpl {
// Default headers
Headers default_headers_;

// Header writer
std::function<ssize_t(Stream &, Headers &)> header_writer_ =
detail::write_headers;

// Settings
std::string client_cert_path_;
std::string client_key_path_;
Expand Down Expand Up @@ -1539,6 +1552,9 @@ class Client {

void set_default_headers(Headers headers);

void
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);

void set_address_family(int family);
void set_tcp_nodelay(bool on);
void set_socket_options(SocketOptions socket_options);
Expand Down Expand Up @@ -5672,6 +5688,12 @@ inline Server &Server::set_default_headers(Headers headers) {
return *this;
}

inline Server &Server::set_header_writer(
std::function<ssize_t(Stream &, Headers &)> const &writer) {
header_writer_ = writer;
return *this;
}

inline Server &Server::set_keep_alive_max_count(size_t count) {
keep_alive_max_count_ = count;
return *this;
Expand Down Expand Up @@ -5866,7 +5888,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
return false;
}

if (!detail::write_headers(bstrm, res.headers)) { return false; }
if (!header_writer_(bstrm, res.headers)) { return false; }

// Flush buffer
auto &data = bstrm.get_buffer();
Expand Down Expand Up @@ -7105,7 +7127,7 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req,
const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path;
bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str());

detail::write_headers(bstrm, req.headers);
header_writer_(bstrm, req.headers);

// Flush buffer
auto &data = bstrm.get_buffer();
Expand Down Expand Up @@ -7916,6 +7938,11 @@ inline void ClientImpl::set_default_headers(Headers headers) {
default_headers_ = std::move(headers);
}

inline void ClientImpl::set_header_writer(
std::function<ssize_t(Stream &, Headers &)> const &writer) {
header_writer_ = writer;
}

inline void ClientImpl::set_address_family(int family) {
address_family_ = family;
}
Expand Down Expand Up @@ -9110,6 +9137,11 @@ inline void Client::set_default_headers(Headers headers) {
cli_->set_default_headers(std::move(headers));
}

inline void Client::set_header_writer(
std::function<ssize_t(Stream &, Headers &)> const &writer) {
cli_->set_header_writer(writer);
}

inline void Client::set_address_family(int family) {
cli_->set_address_family(family);
}
Expand Down
40 changes: 40 additions & 0 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,46 @@ TEST(URLFragmentTest, WithFragment) {
}
}

TEST(HeaderWriter, SetHeaderWriter) {
Server svr;

svr.set_header_writer([](Stream &strm, Headers &hdrs) {
hdrs.emplace("CustomServerHeader", "CustomServerValue");
return detail::write_headers(strm, hdrs);
});
svr.Get("/hi", [](const Request &req, Response &res) {
auto it = req.headers.find("CustomClientHeader");
EXPECT_TRUE(it != req.headers.end());
EXPECT_EQ(it->second, "CustomClientValue");
res.set_content("Hello World!\n", "text/plain");
});

auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
auto se = detail::scope_exit([&] {
svr.stop();
thread.join();
ASSERT_FALSE(svr.is_running());
});

std::this_thread::sleep_for(std::chrono::seconds(1));

{
Client cli(HOST, PORT);
cli.set_header_writer([](Stream &strm, Headers &hdrs) {
hdrs.emplace("CustomClientHeader", "CustomClientValue");
return detail::write_headers(strm, hdrs);
});

auto res = cli.Get("/hi");
EXPECT_TRUE(res);
EXPECT_EQ(200, res->status);

auto it = res->headers.find("CustomServerHeader");
EXPECT_TRUE(it != res->headers.end());
EXPECT_EQ(it->second, "CustomServerValue");
}
}

class ServerTest : public ::testing::Test {
protected:
ServerTest()
Expand Down

0 comments on commit a609330

Please sign in to comment.