Skip to content

Commit

Permalink
fix(server): clear requests to prevent connection stalling
Browse files Browse the repository at this point in the history
  • Loading branch information
mookums committed Nov 26, 2024
1 parent 28b3b37 commit 39bb96a
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 34 deletions.
11 changes: 11 additions & 0 deletions examples/basic/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ pub fn main() !void {
}
}.handler_fn));

try router.serve_route("/echo", Route.init().post({}, struct {
fn handler_fn(ctx: *Context, _: void) !void {
const body = try ctx.allocator.dupe(u8, ctx.request.body);
try ctx.respond(.{
.status = .OK,
.mime = http.Mime.HTML,
.body = body[0..],
});
}
}.handler_fn));

router.serve_not_found(Route.init().get({}, struct {
fn handler_fn(ctx: *Context, _: void) !void {
try ctx.respond(.{
Expand Down
14 changes: 14 additions & 0 deletions src/core/zc_buffer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@ pub const ZeroCopyBuffer = struct {
return self.ptr[0..self.len];
}

const SubsliceOptions = struct {
start: ?usize = null,
end: ?usize = null,
};

pub fn subslice(self: *ZeroCopyBuffer, options: SubsliceOptions) []u8 {
const start: usize = options.start orelse 0;
const end: usize = options.end orelse self.len;
assert(start <= end);
assert(end <= self.len);

return self.ptr[start..end];
}

/// This returns a slice that you can write into for zero-copy uses.
/// This is mostly used when we are passing a buffer to I/O then acting on it.
///
Expand Down
9 changes: 8 additions & 1 deletion src/http/request.zig
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,20 @@ pub const Request = struct {
self.headers.deinit(self.allocator);
}

pub fn clear(self: *Request) void {
self.method = undefined;
self.uri = undefined;
self.body = undefined;
self.headers.clearRetainingCapacity();
}

const RequestParseOptions = struct {
request_bytes_max: u32,
request_uri_bytes_max: u32,
};

pub fn parse_headers(self: *Request, bytes: []const u8, options: RequestParseOptions) HTTPError!void {
self.headers.clearRetainingCapacity();
self.clear();
var total_size: u32 = 0;
var lines = std.mem.tokenizeAny(u8, bytes, "\r\n");

Expand Down
87 changes: 54 additions & 33 deletions src/http/server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ pub fn Server(comptime security: Security) type {
provision.socket = Cross.socket.INVALID_SOCKET;
provision.job = .empty;
_ = provision.arena.reset(.{ .retain_with_limit = config.connection_arena_bytes_retain });

provision.request.clear();
provision.response.clear();

if (provision.recv_buffer.len > config.list_recv_bytes_retain) {
Expand Down Expand Up @@ -350,30 +352,39 @@ pub fn Server(comptime security: Security) type {

log.debug("{d} - recv triggered", .{provision.index});

// recv_count is how many bytes we have read off the socket
const recv_count: usize = @intCast(length);
recv_job.count += recv_count;
const pre_recv_buffer = provision.buffer[0..recv_count];

if (comptime security == .tls) {
const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType);
const tls_ptr: *TLSType = &tls_slice[provision.index];
assert(tls_ptr.* != null);
// this is how many http bytes we have received
const http_bytes_count: usize = blk: {
if (comptime security == .tls) {
const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType);
const tls_ptr: *TLSType = &tls_slice[provision.index];
assert(tls_ptr.* != null);

const decrypted = tls_ptr.*.?.decrypt(pre_recv_buffer) catch |e| {
log.err("{d} - decrypt failed: {any}", .{ provision.index, e });
provision.job = .close;
try rt.net.close(provision, close_task, provision.socket);
return error.TLSDecryptFailed;
};
const decrypted = tls_ptr.*.?.decrypt(provision.buffer[0..recv_count]) catch |e| {
log.err("{d} - decrypt failed: {any}", .{ provision.index, e });
provision.job = .close;
try rt.net.close(provision, close_task, provision.socket);
return error.TLSDecryptFailed;
};

const area = try provision.recv_buffer.get_write_area(decrypted.len);
std.mem.copyForwards(u8, area, decrypted);
provision.recv_buffer.mark_written(decrypted.len);
} else {
provision.recv_buffer.mark_written(recv_count);
}
// since we haven't marked the write yet, we can get a new write area
// that is directly adjacent to the last write block.
const area = try provision.recv_buffer.get_write_area(decrypted.len);
std.mem.copyForwards(u8, area, decrypted);
break :blk decrypted.len;
} else {
break :blk recv_count;
}
};

provision.recv_buffer.mark_written(http_bytes_count);
provision.buffer = try provision.recv_buffer.get_write_area(config.socket_buffer_bytes);
recv_job.count += http_bytes_count;

const status = try on_recv(recv_count, rt, provision, router, config);
const status = try on_recv(http_bytes_count, rt, provision, router, config);
assert(provision.buffer.len == config.socket_buffer_bytes);

switch (status) {
.spawned => return,
Expand Down Expand Up @@ -839,7 +850,7 @@ pub fn Server(comptime security: Security) type {
return try raw_respond(p);
}

inline fn on_recv(
fn on_recv(
// How much we just received
recv_count: usize,
rt: *Runtime,
Expand All @@ -862,15 +873,15 @@ pub fn Server(comptime security: Security) type {

switch (stage) {
.header => {
const starting_length = job.count - recv_count;
// this should never underflow if things are working correctly.
const starting_length = provision.recv_buffer.len - recv_count;
const start = starting_length -| 4;

// Technically, we no longer need to append.
// try provision.recv_buffer.appendSlice(buffer);
provision.buffer = try provision.recv_buffer.get_write_area(config.socket_buffer_bytes);

// need to specify end
const header_ends = std.mem.lastIndexOf(u8, provision.recv_buffer.as_slice()[start..], "\r\n\r\n");
const header_ends = std.mem.lastIndexOf(
u8,
provision.recv_buffer.subslice(.{ .start = start }),
"\r\n\r\n",
);

// Basically, this means we haven't finished processing the header.
if (header_ends == null) {
Expand All @@ -883,10 +894,13 @@ pub fn Server(comptime security: Security) type {
// starting at the index of start.
// The +4 is to account for the slice we match.
const header_end: usize = header_ends.? + start + 4;
provision.request.parse_headers(provision.recv_buffer.as_slice()[0..header_end], .{
.request_bytes_max = config.request_bytes_max,
.request_uri_bytes_max = config.request_uri_bytes_max,
}) catch |e| {
provision.request.parse_headers(
provision.recv_buffer.subslice(.{ .end = header_end }),
.{
.request_bytes_max = config.request_bytes_max,
.request_uri_bytes_max = config.request_uri_bytes_max,
},
) catch |e| {
switch (e) {
HTTPError.ContentTooLarge => {
provision.response.set(.{
Expand Down Expand Up @@ -976,7 +990,10 @@ pub fn Server(comptime security: Security) type {
log.debug("{d} - got whole body with header", .{provision.index});
const body_end = header_end + difference;
provision.request.set(.{
.body = provision.recv_buffer.as_slice()[header_end..body_end],
.body = provision.recv_buffer.subslice(.{
.start = header_end,
.end = body_end,
}),
});
return try route_and_respond(rt, provision, router);
} else {
Expand Down Expand Up @@ -1020,6 +1037,7 @@ pub fn Server(comptime security: Security) type {
break :blk try std.fmt.parseInt(u32, length_string, 10);
};

// We factor in the length of the headers.
const request_length = header_end + content_length;

// If this body will be too long, abort early.
Expand All @@ -1034,7 +1052,10 @@ pub fn Server(comptime security: Security) type {

if (job.count >= request_length) {
provision.request.set(.{
.body = provision.recv_buffer.as_slice()[header_end..request_length],
.body = provision.recv_buffer.subslice(.{
.start = header_end,
.end = request_length,
}),
});
return try route_and_respond(rt, provision, router);
} else {
Expand Down

0 comments on commit 39bb96a

Please sign in to comment.