Skip to content

Commit

Permalink
add option for dynamic completion reap batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
mookums committed Sep 23, 2024
1 parent d634161 commit 8442fc7
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 77 deletions.
60 changes: 44 additions & 16 deletions src/async/io_uring.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@ const log = std.log.scoped(.@"async/io_uring");

pub const AsyncIoUring = struct {
runner: *anyopaque,
completions: [256]Completion,

pub fn init(uring: *std.os.linux.IoUring) !AsyncIoUring {
return AsyncIoUring{
.runner = uring,
.completions = [_]Completion{undefined} ** 256,
};
return AsyncIoUring{ .runner = uring };
}

pub fn queue_accept(
Expand Down Expand Up @@ -76,28 +72,60 @@ pub const AsyncIoUring = struct {

pub fn reap(self: *Async) AsyncError![]Completion {
const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(self.runner));
// NOTE: this can be dynamic and then we would just have to make a single call
// which would probably be better.
var cqes: [256]std.os.linux.io_uring_cqe = [_]std.os.linux.io_uring_cqe{undefined} ** 256;
const count = uring.copy_cqes(cqes[0..], 1) catch |e| switch (e) {
// TODO: match error states.
else => unreachable,
};
var total_reaped: u64 = 0;

const min_length = @min(cqes.len, self.completions.len);
{
// only the first one blocks waiting for an initial set of completions.
const count = uring.copy_cqes(cqes[0..min_length], 1) catch |e| switch (e) {
// TODO: match error states.
else => unreachable,
};

const min = @min(self.completions.len, count);
total_reaped += count;

for (0..min) |i| {
self.completions[i] = Completion{
.result = cqes[i].res,
.context = @ptrFromInt(@as(usize, @intCast(cqes[i].user_data))),
// Copy over the first one.
for (0..total_reaped) |i| {
self.completions[i] = Completion{
.result = cqes[i].res,
.context = @ptrFromInt(@as(usize, @intCast(cqes[i].user_data))),
};
}
}

while (total_reaped < self.completions.len) {
const start = total_reaped;
const remaining = self.completions.len - total_reaped;

const count = uring.copy_cqes(cqes[0..remaining], 0) catch |e| switch (e) {
// TODO: match error states.
else => unreachable,
};

if (count == 0) {
return self.completions[0..total_reaped];
}

total_reaped += count;

for (start..total_reaped) |i| {
const cqe_index = i - start;
self.completions[i] = Completion{
.result = cqes[cqe_index].res,
.context = @ptrFromInt(@as(usize, @intCast(cqes[cqe_index].user_data))),
};
}
}

return self.completions[0..min];
return self.completions[0..total_reaped];
}

pub fn to_async(self: *AsyncIoUring) Async {
return Async{
.runner = self.runner,
.completions = self.completions,
._queue_accept = queue_accept,
._queue_recv = queue_recv,
._queue_send = queue_send,
Expand Down
18 changes: 17 additions & 1 deletion src/async/lib.zig
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
const std = @import("std");
const assert = std.debug.assert;
const builtin = @import("builtin");
const Socket = @import("../core/socket.zig").Socket;
const Completion = @import("completion.zig").Completion;
Expand Down Expand Up @@ -27,7 +28,8 @@ pub const AsyncError = error{

pub const Async = struct {
runner: *anyopaque,
completions: [256]Completion,
attached: bool = false,
completions: []Completion = undefined,

_queue_accept: *const fn (
self: *Async,
Expand Down Expand Up @@ -58,11 +60,20 @@ pub const Async = struct {
_reap: *const fn (self: *Async) AsyncError![]Completion,
_submit: *const fn (self: *Async) AsyncError!void,

/// Supplies the completion buffer that the backend uses while
/// submitting and reaping. Calling this is mandatory before invoking
/// any other method on this Async instance (the other methods assert it).
pub fn attach(self: *Async, completions: []Completion) void {
    // Mark attached and store the caller-owned buffer; the two stores
    // are independent, order does not matter.
    self.attached = true;
    self.completions = completions;
}

/// Queues an accept operation on `socket` through the backend vtable.
/// `context` is carried through to the eventual completion.
/// Requires a prior call to attach().
pub fn queue_accept(
    self: *Async,
    context: *anyopaque,
    socket: Socket,
) AsyncError!void {
    assert(self.attached);
    // Plain vtable dispatch — identical to `@call(.auto, ...)`.
    try self._queue_accept(self, context, socket);
}

Expand All @@ -72,6 +83,7 @@ pub const Async = struct {
socket: Socket,
buffer: []u8,
) AsyncError!void {
assert(self.attached);
try @call(.auto, self._queue_recv, .{ self, context, socket, buffer });
}

Expand All @@ -81,6 +93,7 @@ pub const Async = struct {
socket: Socket,
buffer: []const u8,
) AsyncError!void {
assert(self.attached);
try @call(.auto, self._queue_send, .{ self, context, socket, buffer });
}

Expand All @@ -89,14 +102,17 @@ pub const Async = struct {
context: *anyopaque,
socket: Socket,
) AsyncError!void {
assert(self.attached);
try @call(.auto, self._queue_close, .{ self, context, socket });
}

/// Dispatches to the backend's reap implementation and returns the
/// slice of reaped completions. Requires a prior call to attach().
pub fn reap(self: *Async) AsyncError![]Completion {
    assert(self.attached);
    // Direct vtable call — equivalent to the `@call(.auto, ...)` form.
    return try self._reap(self);
}

/// Dispatches to the backend's submit implementation.
/// Requires a prior call to attach().
pub fn submit(self: *Async) AsyncError!void {
    assert(self.attached);
    // Direct vtable call — equivalent to the `@call(.auto, ...)` form.
    try self._submit(self);
}
};
153 changes: 96 additions & 57 deletions src/core/server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ const std = @import("std");
const assert = std.debug.assert;
const log = std.log.scoped(.@"zzz/server");

const Completion = @import("../async/completion.zig").Completion;
const Async = @import("../async/lib.zig").Async;
const AutoAsyncType = @import("../async/lib.zig").AutoAsyncType;
const AsyncType = @import("../async/lib.zig").AsyncType;
Expand Down Expand Up @@ -70,6 +71,11 @@ pub const zzzConfig = struct {
///
/// Default: 1024
size_connections_max: u16 = 1024,
/// Maximum number of completions we can reap
/// with a single call of reap().
///
/// Default: 256
size_completions_reap_max: u16 = 256,
/// Amount of allocated memory retained
/// after an arena is cleared.
///
Expand Down Expand Up @@ -224,7 +230,7 @@ pub fn Server(
}

/// Cleans up the TLS instance.
inline fn clean_tls(tls_ptr: *?TLS) void {
fn clean_tls(tls_ptr: *?TLS) void {
defer tls_ptr.* = null;

assert(tls_ptr.* != null);
Expand Down Expand Up @@ -309,7 +315,6 @@ pub fn Server(
.data = undefined,
};

var accepted = false;
_ = try backend.queue_accept(&first_provision, server_socket);
try backend.submit();

Expand All @@ -323,7 +328,7 @@ pub fn Server(

switch (p.job) {
.accept => {
accepted = true;
_ = try backend.queue_accept(&first_provision, server_socket);
const socket: Socket = completion.result;

if (socket < 0) {
Expand All @@ -333,6 +338,8 @@ pub fn Server(

// Borrow a provision from the pool otherwise close the socket.
const borrowed = provision_pool.borrow(@intCast(completion.result)) catch {
log.warn("out of provision pool entries", .{});
std.posix.close(socket);
continue :reap_loop;
};

Expand Down Expand Up @@ -597,11 +604,6 @@ pub fn Server(
}
}

if (!provision_pool.full and accepted) {
try backend.queue_accept(&first_provision, server_socket);
accepted = false;
}

try backend.submit();
}

Expand All @@ -625,7 +627,8 @@ pub fn Server(
switch (self.backend_type) {
.io_uring => {
// Initialize IO Uring
const base_flags = std.os.linux.IORING_SETUP_COOP_TASKRUN | std.os.linux.IORING_SETUP_SINGLE_ISSUER;
var base_flags: u32 = std.os.linux.IORING_SETUP_COOP_TASKRUN;
base_flags |= std.os.linux.IORING_SETUP_SINGLE_ISSUER;

const uring = try self.allocator.create(std.os.linux.IoUring);
uring.* = try std.os.linux.IoUring.init(
Expand All @@ -640,7 +643,18 @@ pub fn Server(
}
};

{
const completions = try self.allocator.alloc(
Completion,
self.config.size_completions_reap_max,
);

backend.attach(completions);
}

defer {
self.allocator.free(backend.completions);

switch (self.backend_type) {
.io_uring => {
const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(backend.runner));
Expand Down Expand Up @@ -672,59 +686,84 @@ pub fn Server(

// spawn (count-1) new threads.
for (0..thread_count - 1) |i| {
try threads.append(try std.Thread.spawn(.{ .allocator = allocator }, struct {
fn handler_fn(
p_config: ProtocolConfig,
z_config: zzzConfig,
p_backend: Async,
backend_type: AsyncType,
thread_tls_ctx: TLSContextType,
s_socket: Socket,
thread_id: usize,
) void {
var thread_backend = blk: {
switch (backend_type) {
.io_uring => {
const parent_uring: *std.os.linux.IoUring = @ptrCast(@alignCast(p_backend.runner));
assert(parent_uring.fd >= 0);

// Initialize IO Uring
const thread_flags = std.os.linux.IORING_SETUP_COOP_TASKRUN | std.os.linux.IORING_SETUP_SINGLE_ISSUER | std.os.linux.IORING_SETUP_ATTACH_WQ;

var params = std.mem.zeroInit(std.os.linux.io_uring_params, .{
.flags = thread_flags,
.wq_fd = @as(u32, @intCast(parent_uring.fd)),
});
try threads.append(try std.Thread.spawn(
.{ .allocator = allocator },
struct {
fn handler_fn(
p_config: ProtocolConfig,
z_config: zzzConfig,
p_backend: Async,
backend_type: AsyncType,
thread_tls_ctx: TLSContextType,
s_socket: Socket,
thread_id: usize,
) void {
var thread_backend = blk: {
switch (backend_type) {
.io_uring => {
const parent_uring: *std.os.linux.IoUring = @ptrCast(@alignCast(p_backend.runner));
assert(parent_uring.fd >= 0);

// Initialize IO Uring
var thread_flags: u32 = std.os.linux.IORING_SETUP_COOP_TASKRUN;
thread_flags |= std.os.linux.IORING_SETUP_SINGLE_ISSUER;
thread_flags |= std.os.linux.IORING_SETUP_ATTACH_WQ;

var params = std.mem.zeroInit(std.os.linux.io_uring_params, .{
.flags = thread_flags,
.wq_fd = @as(u32, @intCast(parent_uring.fd)),
});

const uring = z_config.allocator.create(std.os.linux.IoUring) catch unreachable;
uring.* = std.os.linux.IoUring.init_params(
std.math.ceilPowerOfTwoAssert(u16, z_config.size_connections_max),
&params,
) catch unreachable;

var io_uring = AsyncIoUring.init(uring) catch unreachable;
break :blk io_uring.to_async();
},
.custom => |inner| break :blk inner,
}
};

const uring = z_config.allocator.create(std.os.linux.IoUring) catch unreachable;
uring.* = std.os.linux.IoUring.init_params(
std.math.ceilPowerOfTwoAssert(u16, z_config.size_connections_max),
&params,
) catch unreachable;
{
const completions = z_config.allocator.alloc(
Completion,
z_config.size_completions_reap_max,
) catch unreachable;

var io_uring = AsyncIoUring.init(uring) catch unreachable;
break :blk io_uring.to_async();
},
.custom => |inner| break :blk inner,
thread_backend.attach(completions);
}
};

defer {
switch (backend_type) {
.io_uring => {
const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(thread_backend.runner));
uring.deinit();
z_config.allocator.destroy(uring);
},
else => {},

defer {
z_config.allocator.free(thread_backend.completions);

switch (backend_type) {
.io_uring => {
const uring: *std.os.linux.IoUring = @ptrCast(@alignCast(thread_backend.runner));
uring.deinit();
z_config.allocator.destroy(uring);
},
else => {},
}
}
}

run(z_config, p_config, &thread_backend, thread_tls_ctx, s_socket) catch |e| {
log.err("thread #{d} failed due to unrecoverable error: {any}", .{ thread_id, e });
};
}
}.handler_fn, .{ protocol_config, self.config, backend, self.backend_type, self.tls_ctx, server_socket, i }));
run(z_config, p_config, &thread_backend, thread_tls_ctx, s_socket) catch |e| {
log.err("thread #{d} failed due to unrecoverable error: {any}", .{ thread_id, e });
};
}
}.handler_fn,
.{
protocol_config,
self.config,
backend,
self.backend_type,
self.tls_ctx,
server_socket,
i,
},
));
}

run(self.config, protocol_config, &backend, self.tls_ctx, server_socket) catch |e| {
Expand Down
Loading

0 comments on commit 8442fc7

Please sign in to comment.