feat(core): add zero copy buffer
mookums committed Nov 23, 2024
1 parent c710435 commit f02348e
Showing 5 changed files with 187 additions and 39 deletions.
1 change: 1 addition & 0 deletions src/core/lib.zig
@@ -1,2 +1,3 @@
pub const Job = @import("job.zig").Job;
pub const ZeroCopyBuffer = @import("zc_buffer.zig").ZeroCopyBuffer;
pub const Pseudoslice = @import("pseudoslice.zig").Pseudoslice;
137 changes: 137 additions & 0 deletions src/core/zc_buffer.zig
@@ -0,0 +1,137 @@
const std = @import("std");
const assert = std.debug.assert;

pub const ZeroCopyBuffer = struct {
allocator: std.mem.Allocator,
ptr: [*]u8,
len: usize,
capacity: usize,

pub fn init(allocator: std.mem.Allocator, capacity: usize) !ZeroCopyBuffer {
const slice = try allocator.alloc(u8, capacity);
return .{
.allocator = allocator,
.ptr = slice.ptr,
.len = 0,
.capacity = capacity,
};
}

pub fn deinit(self: *ZeroCopyBuffer) void {
self.allocator.free(self.ptr[0..self.capacity]);
}

pub fn as_slice(self: *ZeroCopyBuffer) []u8 {
return self.ptr[0..self.len];
}

/// Returns a slice that you can write into for zero-copy use.
/// This is mostly used when passing a buffer to I/O and then acting on the result.
///
/// The returned write area is ONLY valid until the next call to get_write_area
/// or mark_written.
pub fn get_write_area(self: *ZeroCopyBuffer, size: usize) ![]u8 {
const available_space = self.capacity - self.len;
if (available_space >= size) {
return self.ptr[self.len .. self.len + size];
} else {
const old_slice = self.ptr[0..self.capacity];
const new_size = try std.math.ceilPowerOfTwo(usize, self.capacity + size);

if (self.allocator.resize(self.ptr[0..self.capacity], new_size)) {
self.capacity = new_size;
} else {
const new_slice = try self.allocator.alloc(u8, new_size);
@memcpy(new_slice[0..self.len], self.ptr[0..self.len]);
self.allocator.free(old_slice);

self.ptr = new_slice.ptr;
self.capacity = new_slice.len;
}

assert(self.capacity - self.len >= size);
return self.ptr[self.len .. self.len + size];
}
}

pub fn get_write_area_assume_space(self: *ZeroCopyBuffer, size: usize) []u8 {
assert(self.capacity - self.len >= size);
return self.ptr[self.len .. self.len + size];
}

pub fn mark_written(self: *ZeroCopyBuffer, length: usize) void {
assert(self.len + length <= self.capacity);
self.len += length;
}

pub fn shrink_retaining_capacity(self: *ZeroCopyBuffer, new_size: usize) void {
assert(new_size <= self.len);
self.len = new_size;
}

pub fn clear_retaining_capacity(self: *ZeroCopyBuffer) void {
self.len = 0;
}

pub fn clear_and_free(self: *ZeroCopyBuffer) void {
self.allocator.free(self.ptr[0..self.capacity]);
self.len = 0;
self.capacity = 0;
}
};

const testing = std.testing;

test "ZeroCopyBuffer: First" {
const garbage: []const u8 = &[_]u8{212} ** 128;

var zc = try ZeroCopyBuffer.init(testing.allocator, 512);
defer zc.deinit();

const write_area = try zc.get_write_area(garbage.len);
@memcpy(write_area, garbage);
zc.mark_written(write_area.len);

try testing.expectEqualSlices(u8, garbage[0..], zc.as_slice()[0..write_area.len]);
}

test "ZeroCopyBuffer: Growth" {
var zc = try ZeroCopyBuffer.init(testing.allocator, 16);
defer zc.deinit();

const large_data = &[_]u8{1} ** 32;
const write_area = try zc.get_write_area(large_data.len);
@memcpy(write_area, large_data);
zc.mark_written(write_area.len);

try testing.expect(zc.capacity >= 32);
try testing.expectEqualSlices(u8, large_data, zc.as_slice());
}

test "ZeroCopyBuffer: Multiple Writes" {
var zc = try ZeroCopyBuffer.init(testing.allocator, 64);
defer zc.deinit();

const data1 = "Hello, ";
const data2 = "World!";

const area1 = try zc.get_write_area(data1.len);
@memcpy(area1, data1);
zc.mark_written(area1.len);

const area2 = try zc.get_write_area(data2.len);
@memcpy(area2, data2);
zc.mark_written(area2.len);

try testing.expectEqualSlices(u8, "Hello, World!", zc.as_slice());
}

test "ZeroCopyBuffer: Zero Size Write" {
var zc = try ZeroCopyBuffer.init(testing.allocator, 8);
defer zc.deinit();

const area = try zc.get_write_area(0);
try testing.expect(area.len == 0);
zc.mark_written(0);
try testing.expect(zc.len == 0);
}
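For reference, the usage pattern described by the doc comment on get_write_area looks roughly like the sketch below. This is not part of the commit; recv_into and the 4096-byte request size are illustrative, and the socket is assumed to come from elsewhere.

const std = @import("std");
const ZeroCopyBuffer = @import("zc_buffer.zig").ZeroCopyBuffer;

/// Sketch of the zero-copy recv pattern: hand the write area directly to the
/// syscall, then record how many bytes were actually written.
fn recv_into(zc: *ZeroCopyBuffer, socket: std.posix.socket_t) !usize {
    const area = try zc.get_write_area(4096);
    const count = try std.posix.recv(socket, area, 0);
    zc.mark_written(count);
    return count;
}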
13 changes: 7 additions & 6 deletions src/http/provision.zig
@@ -1,6 +1,7 @@
const std = @import("std");

const Job = @import("../core/job.zig").Job;
const ZeroCopyBuffer = @import("../core/zc_buffer.zig").ZeroCopyBuffer;
const Capture = @import("routing_trie.zig").Capture;
const QueryMap = @import("routing_trie.zig").QueryMap;
const Request = @import("request.zig").Request;
@@ -13,7 +14,7 @@ pub const Provision = struct {
job: Job,
socket: std.posix.socket_t,
buffer: []u8,
recv_buffer: std.ArrayList(u8),
recv_buffer: ZeroCopyBuffer,
arena: std.heap.ArenaAllocator,
captures: []Capture,
queries: QueryMap,
@@ -31,12 +32,12 @@ pub const Provision = struct {
for (provisions) |*provision| {
provision.job = .empty;
provision.socket = undefined;
// Create Buffer
provision.buffer = ctx.allocator.alloc(u8, config.socket_buffer_bytes) catch {
@panic("attempting to statically allocate more memory than available. (Socket Buffer)");
};
// Create Recv Buffer
provision.recv_buffer = std.ArrayList(u8).init(ctx.allocator);
provision.recv_buffer = ZeroCopyBuffer.init(ctx.allocator, config.socket_buffer_bytes) catch {
@panic("attempting to statically allocate more memory than available. (ZeroCopyBuffer)");
};
// Create Buffer
provision.buffer = provision.recv_buffer.get_write_area_assume_space(config.socket_buffer_bytes);
// Create the Context Arena
provision.arena = std.heap.ArenaAllocator.init(ctx.allocator);

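A consequence of the hunk above: provision.buffer is now a view into the ZeroCopyBuffer's backing storage rather than a separately allocated slice, so any call that may grow the buffer can move that storage and invalidate it; server.zig below re-fetches the write area for exactly that reason. A standalone sketch of the constraint (not part of this commit):

const std = @import("std");
const ZeroCopyBuffer = @import("../core/zc_buffer.zig").ZeroCopyBuffer;

test "sketch: refresh an aliased write area after growth" {
    var zc = try ZeroCopyBuffer.init(std.testing.allocator, 16);
    defer zc.deinit();

    // Aliases zc's backing storage, much like provision.buffer above.
    var buffer = zc.get_write_area_assume_space(16);
    @memset(buffer, 0);
    zc.mark_written(16);

    // Growth may reallocate and move the backing pointer, so a previously
    // returned write area must not be reused; fetch a fresh one instead.
    buffer = try zc.get_write_area(64);
    try std.testing.expect(zc.capacity >= 16 + 64);
}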
1 change: 1 addition & 0 deletions src/http/router.zig
@@ -254,6 +254,7 @@ pub fn Router(comptime Server: type) type {
const extension_start = std.mem.lastIndexOfScalar(u8, search_path, '.');
const mime: Mime = blk: {
if (extension_start) |start| {
if (start == search_path.len - 1) break :blk Mime.BIN;
break :blk Mime.from_extension(search_path[start + 1 ..]);
} else {
break :blk Mime.BIN;
74 changes: 41 additions & 33 deletions src/http/server.zig
@@ -243,11 +243,10 @@ pub fn Server(comptime security: Security) type {
_ = provision.arena.reset(.{ .retain_with_limit = config.connection_arena_bytes_retain });
provision.response.clear();

if (provision.recv_buffer.items.len > config.list_recv_bytes_retain) {
provision.recv_buffer.shrinkRetainingCapacity(config.list_recv_bytes_retain);
} else {
provision.recv_buffer.clearRetainingCapacity();
if (provision.recv_buffer.len > config.list_recv_bytes_retain) {
provision.recv_buffer.shrink_retaining_capacity(config.list_recv_bytes_retain);
}
provision.recv_buffer.clear_retaining_capacity();

pool.release(provision.index);

@@ -352,25 +351,26 @@ pub fn Server(comptime security: Security) type {
recv_job.count += recv_count;
const pre_recv_buffer = provision.buffer[0..recv_count];

const recv_buffer = blk: {
switch (comptime security) {
.tls => |_| {
const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType);
const tls_ptr: *TLSType = &tls_slice[provision.index];
assert(tls_ptr.* != null);

break :blk tls_ptr.*.?.decrypt(pre_recv_buffer) catch |e| {
log.err("{d} - decrypt failed: {any}", .{ provision.index, e });
provision.job = .close;
try rt.net.close(provision, close_task, provision.socket);
return error.TLSDecryptFailed;
};
},
.plain => break :blk pre_recv_buffer,
}
};
if (comptime security == .tls) {
const tls_slice = rt.storage.get("__zzz_tls_slice", []TLSType);
const tls_ptr: *TLSType = &tls_slice[provision.index];
assert(tls_ptr.* != null);

const status = try on_recv(recv_buffer, rt, provision, router, config);
const decrypted = tls_ptr.*.?.decrypt(pre_recv_buffer) catch |e| {
log.err("{d} - decrypt failed: {any}", .{ provision.index, e });
provision.job = .close;
try rt.net.close(provision, close_task, provision.socket);
return error.TLSDecryptFailed;
};

const area = try provision.recv_buffer.get_write_area(decrypted.len);
std.mem.copyForwards(u8, area, decrypted);
provision.recv_buffer.mark_written(decrypted.len);
} else {
provision.recv_buffer.mark_written(recv_count);
}

const status = try on_recv(recv_count, rt, provision, router, config);

switch (status) {
.spawned => return,
@@ -535,7 +535,8 @@ pub fn Server(comptime security: Security) type {
_ = provision.arena.reset(.{
.retain_with_limit = config.connection_arena_bytes_retain,
});
provision.recv_buffer.clearRetainingCapacity();

provision.recv_buffer.clear_retaining_capacity();
provision.job = .{ .recv = .{ .count = 0 } };

try rt.net.recv(
@@ -845,7 +846,8 @@ pub fn Server(comptime security: Security) type {
}

inline fn on_recv(
buffer: []const u8,
// How much we just received
recv_count: usize,
rt: *Runtime,
provision: *Provision,
router: *const Router,
@@ -866,9 +868,15 @@

switch (stage) {
.header => {
const start = provision.recv_buffer.items.len -| 4;
try provision.recv_buffer.appendSlice(buffer);
const header_ends = std.mem.lastIndexOf(u8, provision.recv_buffer.items[start..], "\r\n\r\n");
const starting_length = job.count - recv_count;
const start = starting_length -| 4;

// Technically, we no longer need to append.
// try provision.recv_buffer.appendSlice(buffer);
provision.buffer = try provision.recv_buffer.get_write_area(config.socket_buffer_bytes);

// TODO: the search should also specify an explicit end bound.
const header_ends = std.mem.lastIndexOf(u8, provision.recv_buffer.as_slice()[start..], "\r\n\r\n");

// Basically, this means we haven't finished processing the header.
if (header_ends == null) {
@@ -879,7 +887,7 @@ pub fn Server(comptime security: Security) type {
log.debug("{d} - parsing header", .{provision.index});
// The +4 is to account for the slice we match.
const header_end: u32 = @intCast(header_ends.? + 4);
provision.request.parse_headers(provision.recv_buffer.items[0..header_end], .{
provision.request.parse_headers(provision.recv_buffer.as_slice()[0..header_end], .{
.size_request_max = config.request_bytes_max,
.size_request_uri_max = config.request_uri_bytes_max,
}) catch |e| {
@@ -965,14 +973,14 @@ pub fn Server(comptime security: Security) type {
break :blk try std.fmt.parseInt(u32, length_string, 10);
};

if (header_end < provision.recv_buffer.items.len) {
const difference = provision.recv_buffer.items.len - header_end;
if (header_end < provision.recv_buffer.len) {
const difference = provision.recv_buffer.len - header_end;
if (difference == content_length) {
// Whole Body
log.debug("{d} - got whole body with header", .{provision.index});
const body_end = header_end + difference;
provision.request.set(.{
.body = provision.recv_buffer.items[header_end..body_end],
.body = provision.recv_buffer.as_slice()[header_end..body_end],
});
return try route_and_respond(rt, provision, router);
} else {
@@ -981,7 +989,7 @@ pub fn Server(comptime security: Security) type {
stage = .{ .body = header_end };
return .recv;
}
} else if (header_end == provision.recv_buffer.items.len) {
} else if (header_end == provision.recv_buffer.len) {
// Either the body has length 0, or we only received the header so far.
if (content_length == 0) {
log.debug("{d} - got body of length 0", .{provision.index});
@@ -1030,7 +1038,7 @@ pub fn Server(comptime security: Security) type {

if (job.count >= request_length) {
provision.request.set(.{
.body = provision.recv_buffer.items[header_end..request_length],
.body = provision.recv_buffer.as_slice()[header_end..request_length],
});
return try route_and_respond(rt, provision, router);
} else {
