From 6b56293f832056756e601944bfa897c4595559cb Mon Sep 17 00:00:00 2001 From: Trevor Berrange Sanchez Date: Thu, 9 Jan 2025 01:47:49 +0100 Subject: [PATCH] RPC server: io_uring upgrade WIP --- src/accountsdb/snapshots.zig | 4 +- src/rpc/server.zig | 423 +++++++++----------------- src/rpc/server/LinuxIoUring.zig | 515 ++++++++++++++++++++++++++++++++ src/rpc/server/connection.zig | 154 ++++++++++ src/rpc/server/requests.zig | 203 +++++++++++++ src/utils/fmt.zig | 2 +- 6 files changed, 1014 insertions(+), 287 deletions(-) create mode 100644 src/rpc/server/LinuxIoUring.zig create mode 100644 src/rpc/server/connection.zig create mode 100644 src/rpc/server/requests.zig diff --git a/src/accountsdb/snapshots.zig b/src/accountsdb/snapshots.zig index d0c0844e8..9d685ab89 100644 --- a/src/accountsdb/snapshots.zig +++ b/src/accountsdb/snapshots.zig @@ -2225,7 +2225,7 @@ pub const FullSnapshotFileInfo = struct { slot: Slot, hash: Hash, - const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("snapshot-{[slot]d}-{[hash]s}.tar.zst"); + pub const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("snapshot-{[slot]d}-{[hash]s}.tar.zst"); pub const SnapshotArchiveNameStr = SnapshotArchiveNameFmtSpec.BoundedArrayValue(.{ .slot = std.math.maxInt(Slot), @@ -2352,7 +2352,7 @@ pub const IncrementalSnapshotFileInfo = struct { }; } - const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("incremental-snapshot-{[base_slot]d}-{[slot]d}-{[hash]s}.tar.zst"); + pub const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("incremental-snapshot-{[base_slot]d}-{[slot]d}-{[hash]s}.tar.zst"); pub const SnapshotArchiveNameStr = SnapshotArchiveNameFmtSpec.BoundedArrayValue(.{ .base_slot = std.math.maxInt(Slot), diff --git a/src/rpc/server.zig b/src/rpc/server.zig index c320e3900..b8a487f27 100644 --- a/src/rpc/server.zig +++ b/src/rpc/server.zig @@ -1,21 +1,24 @@ +const builtin = @import("builtin"); const std = @import("std"); const sig = @import("../sig.zig"); +const connection = @import("server/connection.zig"); +const requests = @import("server/requests.zig"); + +const IoUring = std.os.linux.IoUring; + const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo; const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo; const ThreadPool = sig.sync.ThreadPool; -const LOGGER_SCOPE = "rpc.Server"; -const ScopedLogger = sig.trace.log.ScopedLogger(LOGGER_SCOPE); - pub const Server = struct { //! Basic usage: //! ```zig //! var server = try Server.init(.{...}); //! defer server.joinDeinit(); //! - //! try server.serveSpawnDetached(); // or `.serveDirect`, if the caller can block or is managing the separate thread themselves. + //! try server.serve(); // or `.serveSpawn` to spawn a thread and return its handle. //! ``` allocator: std.mem.Allocator, @@ -27,13 +30,27 @@ pub const Server = struct { /// Wait group for all currently running tasks, used to wait for /// all of them to finish before deinitializing. wait_group: std.Thread.WaitGroup, - thread_pool: *ThreadPool, + work_pool: WorkPool, + tcp: std.net.Server, /// Must not be mutated. 
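    /// It is set by `init`, which clamps it to at least `MIN_READ_BUFFER_SIZE`.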
read_buffer_size: usize, - tcp: std.net.Server, - pub const MIN_READ_BUFFER_SIZE = 256; + pub const LOGGER_SCOPE = "rpc.Server"; + pub const ScopedLogger = sig.trace.log.ScopedLogger(LOGGER_SCOPE); + + pub const MIN_READ_BUFFER_SIZE = 4096; + + pub const InitError = + std.net.Address.ListenError || + std.posix.MMapError || + std.posix.UnexpectedError || + WorkPool.LinuxIoUring.InitError || + WorkPool.LinuxIoUring.EnterError || + error{ + SubmissionQueueFull, + FailedToAcceptMultishot, + }; /// The returned result must be pinned to a memory location before calling any methods. pub fn init(params: struct { @@ -47,14 +64,15 @@ pub const Server = struct { /// given time with respect to the contents of the specified `snapshot_dir`. latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), - thread_pool: *ThreadPool, - /// The size for the read buffer allocated to every request. /// Clamped to be greater than or equal to `MIN_READ_BUFFER_SIZE`. read_buffer_size: u32, /// The socket address to listen on for incoming HTTP and/or RPC requests. socket_addr: std.net.Address, - }) std.net.Address.ListenError!Server { + + /// Set to true to disable taking advantage of native work pool strategies (ie io_uring). + force_basic_work_pool: bool = false, + }) InitError!Server { var tcp_server = try params.socket_addr.listen(.{ // NOTE: ideally we would be doing this nonblockingly, however this doesn't work properly on mac, // so for testing purposes we can't test the `serve` functionality directly. @@ -62,6 +80,39 @@ pub const Server = struct { }); errdefer tcp_server.deinit(); + var work_pool: WorkPool = if (params.force_basic_work_pool) + .basic + else switch (WorkPool.LinuxIoUring.can_use) { + .no => .basic, + .yes, .check => |can_use| blk: { + var io_uring = IoUring.init(32, 0) catch |err| return switch (err) { + error.SystemOutdated, + error.PermissionDenied, + => |e| switch (can_use) { + .yes => e, + .check => break :blk .basic, + .no => comptime unreachable, + }, + else => |e| e, + }; + errdefer io_uring.deinit(); + + _ = try io_uring.accept_multishot( + @bitCast(WorkPool.LinuxIoUring.Entry.ACCEPT), + tcp_server.stream.handle, + null, + null, + std.os.linux.SOCK.CLOEXEC, + ); + if (try io_uring.submit() != 1) { + return error.FailedToAcceptMultishot; + } + + break :blk .{ .linux_io_uring = .{ .io_uring = io_uring } }; + }, + }; + errdefer work_pool.deinit(); + return .{ .allocator = params.allocator, .logger = params.logger.withScope(LOGGER_SCOPE), @@ -70,7 +121,7 @@ pub const Server = struct { .latest_snapshot_gen_info = params.latest_snapshot_gen_info, .wait_group = .{}, - .thread_pool = params.thread_pool, + .work_pool = work_pool, .read_buffer_size = @max(params.read_buffer_size, MIN_READ_BUFFER_SIZE), .tcp = tcp_server, @@ -98,295 +149,99 @@ pub const Server = struct { exit: *std.atomic.Value(bool), ) AcceptAndServeConnectionError!void { while (!exit.load(.acquire)) { - try server.acceptAndServeConnection(); + try server.acceptAndServeConnection(.{}); } } pub const AcceptAndServeConnectionError = std.mem.Allocator.Error || std.http.Server.ReceiveHeadError || - AcceptConnectionError; - - pub fn acceptAndServeConnection(server: *Server) AcceptAndServeConnectionError!void { - const conn = (try acceptConnection(&server.tcp, server.logger)).?; - errdefer conn.stream.close(); - - server.wait_group.start(); - errdefer server.wait_group.finish(); - - const new_hct = try HandleConnectionTask.createAndReceiveHead(server, conn); - errdefer new_hct.destroyAndClose(); - - 
server.thread_pool.schedule(ThreadPool.Batch.from(&new_hct.task)); - } -}; + WorkPool.LinuxIoUring.EnterError || + WorkPool.LinuxIoUring.AcceptAndServeConnectionsError || + AcceptHandledError || + requests.HandleRequestError; -const HandleConnectionTask = struct { - task: ThreadPool.Task, - server: *Server, - http_server: std.http.Server, - request: std.http.Server.Request, - - fn createAndReceiveHead( + pub fn acceptAndServeConnection( server: *Server, - conn: std.net.Server.Connection, - ) (std.http.Server.ReceiveHeadError || std.mem.Allocator.Error)!*HandleConnectionTask { - const allocator = server.allocator; - - const hct_buf_align = @alignOf(HandleConnectionTask); - const hct_buf_size = initBufferSize(server.read_buffer_size); - - const hct_buffer = try allocator.alignedAlloc(u8, hct_buf_align, hct_buf_size); - errdefer server.allocator.free(hct_buffer); - - const hct: *HandleConnectionTask = std.mem.bytesAsValue( - HandleConnectionTask, - hct_buffer[0..@sizeOf(HandleConnectionTask)], - ); - hct.* = .{ - .task = .{ .callback = callback }, - .server = server, - .http_server = std.http.Server.init(conn, getReadBuffer(server.read_buffer_size, hct)), - .request = try hct.http_server.receiveHead(), - }; - - return hct; - } - - /// Does not release the connection. - fn destroyAndClose(hct: *HandleConnectionTask) void { - const allocator = hct.server.allocator; - - const full_buffer = getFullBuffer(hct.server.read_buffer_size, hct); - defer allocator.free(full_buffer); - - const connection = hct.http_server.connection; - defer connection.stream.close(); - } - - fn initBufferSize(read_buffer_size: usize) usize { - return @sizeOf(HandleConnectionTask) + read_buffer_size; - } - - fn getFullBuffer( - read_buffer_size: usize, - hct: *HandleConnectionTask, - ) []align(@alignOf(HandleConnectionTask)) u8 { - const ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(hct); - return ptr[0..initBufferSize(read_buffer_size)]; - } - - fn getReadBuffer( - read_buffer_size: usize, - hct: *HandleConnectionTask, - ) []u8 { - return getFullBuffer(read_buffer_size, hct)[@sizeOf(HandleConnectionTask)..]; + options: struct { + /// The maximum number of connections to handle during this call. 
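+            /// Only the io_uring-backed work pool batches connections; the basic
+            /// work pool always accepts and serves exactly one connection per call.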
+ max_connections_to_handle: u8 = std.math.maxInt(u8), + }, + ) AcceptAndServeConnectionError!void { + switch (server.work_pool) { + .basic => { + const conn = try acceptHandled(&server.tcp); + defer conn.stream.close(); + + server.wait_group.start(); + defer server.wait_group.finish(); + + const buffer = try server.allocator.alloc(u8, server.read_buffer_size); + defer server.allocator.free(buffer); + + var http_server = std.http.Server.init(conn, buffer); + var request = try http_server.receiveHead(); + + try requests.handleRequest( + server.logger, + &request, + server.snapshot_dir, + server.latest_snapshot_gen_info, + ); + }, + .linux_io_uring => |*linux| { + try linux.acceptAndServeConnections(server, options.max_connections_to_handle); + }, + } } +}; - fn callback(task: *ThreadPool.Task) void { - const hct: *HandleConnectionTask = @fieldParentPtr("task", task); - defer hct.destroyAndClose(); - - const server = hct.server; - const logger = server.logger; +pub const WorkPool = union(enum) { + basic, + linux_io_uring: switch (LinuxIoUring.can_use) { + .yes, .check => LinuxIoUring, + .no => noreturn, + }, - const wait_group = &server.wait_group; - defer wait_group.finish(); + const LinuxIoUring = @import("server/LinuxIoUring.zig"); - handleRequest( - logger, - &hct.request, - server.snapshot_dir, - server.latest_snapshot_gen_info, - ) catch |err| { - if (@errorReturnTrace()) |stack_trace| { - logger.err().logf("{s}\n{}", .{ @errorName(err), stack_trace }); - } else { - logger.err().logf("{s}", .{@errorName(err)}); - } - }; + pub fn deinit(wp: *WorkPool) void { + switch (wp.*) { + .basic => {}, + .linux_io_uring => |*linux| linux.deinit(), + } } }; -fn handleRequest( - logger: ScopedLogger, - request: *std.http.Server.Request, - snapshot_dir: std.fs.Dir, - latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), -) !void { - const conn_address = request.server.connection.address; - - logger.info().logf("Responding to request from {}: {} {s}", .{ - conn_address, methodFmt(request.head.method), request.head.target, - }); - switch (request.head.method) { - .POST => { - logger.err().logf("{} tried to invoke our RPC", .{conn_address}); - return try request.respond("RPCs are not yet implemented", .{ - .status = .service_unavailable, - .keep_alive = false, - }); - }, - .GET => get_blk: { - if (!std.mem.startsWith(u8, request.head.target, "/")) break :get_blk; - const path = request.head.target[1..]; - - // we hold the lock for the entirety of this process in order to prevent - // the snapshot generation process from deleting the associated snapshot. 
- const maybe_latest_snapshot_gen_info, // - var latest_snapshot_info_lg // - = latest_snapshot_gen_info_rw.readWithLock(); - defer latest_snapshot_info_lg.unlock(); - - const full_info: ?FullSnapshotFileInfo, // - const inc_info: ?IncrementalSnapshotFileInfo // - = blk: { - const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse - break :blk .{ null, null }; - const latest_full = latest_snapshot_gen_info.full; - const full_info: FullSnapshotFileInfo = .{ - .slot = latest_full.slot, - .hash = latest_full.hash, - }; - const latest_incremental = latest_snapshot_gen_info.inc orelse - break :blk .{ full_info, null }; - const inc_info: IncrementalSnapshotFileInfo = .{ - .base_slot = latest_full.slot, - .slot = latest_incremental.slot, - .hash = latest_incremental.hash, - }; - break :blk .{ full_info, inc_info }; - }; - - logger.debug().logf("Available full: {?s}", .{ - if (full_info) |info| info.snapshotArchiveName().constSlice() else null, - }); - logger.debug().logf("Available inc: {?s}", .{ - if (inc_info) |info| info.snapshotArchiveName().constSlice() else null, - }); - - if (full_info) |full| { - const full_archive_name_bounded = full.snapshotArchiveName(); - const full_archive_name = full_archive_name_bounded.constSlice(); - if (std.mem.eql(u8, path, full_archive_name)) { - const archive_file = try snapshot_dir.openFile(full_archive_name, .{}); - defer archive_file.close(); - var send_buffer: [4096]u8 = undefined; - try httpResponseSendFile(request, archive_file, &send_buffer); - return; - } - } - - if (inc_info) |inc| { - const inc_archive_name_bounded = inc.snapshotArchiveName(); - const inc_archive_name = inc_archive_name_bounded.constSlice(); - if (std.mem.eql(u8, path, inc_archive_name)) { - const archive_file = try snapshot_dir.openFile(inc_archive_name, .{}); - defer archive_file.close(); - var send_buffer: [4096]u8 = undefined; - try httpResponseSendFile(request, archive_file, &send_buffer); - return; - } - } - }, - else => {}, - } - - logger.err().logf( - "{} made an unrecognized request '{} {s}'", - .{ conn_address, methodFmt(request.head.method), request.head.target }, - ); - try request.respond("", .{ - .status = .not_found, - .keep_alive = false, - }); -} - -fn httpResponseSendFile( - request: *std.http.Server.Request, - archive_file: std.fs.File, - send_buffer: []u8, -) !void { - const archive_len = try archive_file.getEndPos(); - - var response = request.respondStreaming(.{ - .send_buffer = send_buffer, - .content_length = archive_len, - }); - const writer = sig.utils.io.narrowAnyWriter( - response.writer(), - std.http.Server.Response.WriteError, - ); - - const Fifo = std.fifo.LinearFifo(u8, .{ .Static = 1 }); - var fifo: Fifo = Fifo.init(); - try archive_file.seekTo(0); - try fifo.pump(archive_file.reader(), writer); - - try response.end(); -} - -const AcceptConnectionError = error{ - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - SystemResources, - ProtocolFailure, - BlockedByFirewall, - NetworkSubsystemFailed, -} || std.posix.UnexpectedError; - -fn acceptConnection( +const AcceptHandledError = connection.HandleAcceptError || error{ConnectionAborted}; +fn acceptHandled( tcp_server: *std.net.Server, - logger: ScopedLogger, -) AcceptConnectionError!?std.net.Server.Connection { - const conn = tcp_server.accept() catch |err| switch (err) { - error.Unexpected, - => |e| return e, - - error.ProcessFdQuotaExceeded, - error.SystemFdQuotaExceeded, - error.SystemResources, - error.ProtocolFailure, - error.BlockedByFirewall, - error.NetworkSubsystemFailed, - => 
|e| return e,
-
-        error.FileDescriptorNotASocket,
-        error.SocketNotListening,
-        error.OperationNotSupported,
-        => @panic("Improperly initialized server."),
-
-        error.WouldBlock,
-        => return null,
-
-        error.ConnectionResetByPeer,
-        error.ConnectionAborted,
-        => |e| {
-            logger.warn().logf("{}", .{e});
-            return null;
-        },
-    };
-
-    return conn;
-}
-
-fn methodFmt(method: std.http.Method) MethodFmt {
-    return .{ .method = method };
-}
-
-const MethodFmt = struct {
-    method: std.http.Method,
-    pub fn format(
-        fmt: MethodFmt,
-        comptime fmt_str: []const u8,
-        fmt_options: std.fmt.FormatOptions,
-        writer: anytype,
-    ) @TypeOf(writer).Error!void {
-        _ = fmt_options;
-        if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt);
-        try fmt.method.write(writer);
-    }
-};
+) AcceptHandledError!std.net.Server.Connection {
+    while (true) {
+        var addr: std.net.Address = .{ .any = undefined };
+        var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(addr.any));
+        const rc = if (!builtin.target.isDarwin()) std.posix.system.accept4(
+            tcp_server.stream.handle,
+            &addr.any,
+            &addr_len,
+            std.posix.SOCK.CLOEXEC,
+        ) else std.posix.system.accept(
+            tcp_server.stream.handle,
+            &addr.any,
+            &addr_len,
+        );
+
+        return switch (try connection.handleAcceptResult(std.posix.errno(rc))) {
+            .intr => continue,
+            .conn_aborted => return error.ConnectionAborted,
+            .again => std.debug.panic("We're not using nonblock, but encountered EAGAIN.", .{}),
+            // the syscall returns the new fd as a `usize`; narrow it to the fd type
+            .success => return .{
+                .stream = .{ .handle = @intCast(rc) },
+                .address = addr,
+            },
+        };
+    }
+}

 test Server {
     const allocator = std.testing.allocator;
@@ -481,9 +336,9 @@ test Server {
         .logger = logger,
         .snapshot_dir = snap_dir,
         .latest_snapshot_gen_info = &accountsdb.latest_snapshot_gen_info,
-        .thread_pool = &thread_pool,
         .socket_addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, rpc_port),
         .read_buffer_size = 4096,
+        .force_basic_work_pool = false,
     });
     defer rpc_server.joinDeinit();
@@ -517,7 +372,7 @@ fn testExpectSnapshotResponse(
     );
     const snap_url = try std.Uri.parse(snap_url_str_bounded.constSlice());

-    const serve_thread = try std.Thread.spawn(.{}, Server.acceptAndServeConnection, .{rpc_server});
+    const serve_thread = try std.Thread.spawn(.{}, Server.acceptAndServeConnection, .{ rpc_server, .{} });
     const actual_data = try testDownloadSelfSnapshot(allocator, snap_url);
     defer allocator.free(actual_data);
     serve_thread.join();
diff --git a/src/rpc/server/LinuxIoUring.zig b/src/rpc/server/LinuxIoUring.zig
new file mode 100644
index 000000000..dd1df36e4
--- /dev/null
+++ b/src/rpc/server/LinuxIoUring.zig
@@ -0,0 +1,515 @@
+const builtin = @import("builtin");
+const std = @import("std");
+const sig = @import("../../sig.zig");
+
+const connection = @import("connection.zig");
+const requests = @import("requests.zig");
+
+const IoUring = std.os.linux.IoUring;
+
+const Server = sig.rpc.Server;
+const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo;
+
+const LinuxIoUring = @This();
+io_uring: IoUring,
+
+// must be `pub` so that `WorkPool.deinit` in server.zig can call it
+pub fn deinit(linux: *LinuxIoUring) void {
+    linux.io_uring.deinit();
+}
+
+pub const can_use: enum { no, yes, check } = switch (builtin.os.getVersionRange()) {
+    .linux => |version| can_use: {
+        const min_version: std.SemanticVersion = .{ .major = 6, .minor = 0, .patch = 0 };
+        const is_at_least = version.isAtLeast(min_version) orelse break :can_use .check;
+        break :can_use if (is_at_least) .yes else .no;
+    },
+    else => .no,
+};
+
+pub const InitError = std.posix.MMapError || error{
+    EntriesZero,
+    EntriesNotPowerOfTwo,
+    ParamsOutsideAccessibleAddressSpace,
+    ArgumentsInvalid,
+    ProcessFdQuotaExceeded,
+    SystemFdQuotaExceeded,
+    SystemResources,
+    PermissionDenied,
+    SystemOutdated,
+};
+
+pub const EnterError = error{
+    SystemResources,
+    FileDescriptorInvalid,
+    FileDescriptorInBadState,
+    CompletionQueueOvercommitted,
+    SubmissionQueueEntryInvalid,
+    BufferInvalid,
+    RingShuttingDown,
+    OpcodeNotSupported,
+    SignalInterrupt,
+};
+
+pub const AcceptAndServeConnectionsError =
+    // std.posix.GetSockNameError ||
+    std.mem.Allocator.Error ||
+    std.fs.File.OpenError ||
+    std.fs.File.GetSeekPosError ||
+    std.posix.PipeError ||
+    connection.HandleAcceptError ||
+    connection.HandleRecvError ||
+    connection.HandleSpliceError ||
+    EnterError ||
+    std.http.Server.Request.Head.ParseError ||
+    error{RequestBodyTooLong} ||
+    error{SubmissionQueueFull};
+
+pub fn acceptAndServeConnections(
+    linux: *LinuxIoUring,
+    server: *Server,
+    max_connections_to_handle: u8,
+) AcceptAndServeConnectionsError!void {
+    std.debug.assert(linux == &server.work_pool.linux_io_uring);
+    _ = try linux.io_uring.submit_and_wait(1);
+
+    var cqes_buf: [255]std.os.linux.io_uring_cqe = undefined;
+    const cqes = cqes: {
+        const cqes_count = try linux.io_uring.copy_cqes(cqes_buf[0..max_connections_to_handle], 0);
+        break :cqes cqes_buf[0..cqes_count];
+    };
+
+    var first_err: ?AcceptAndServeConnectionsError = null;
+
+    cqe_loop: for (cqes, 0..) |cqe, i| {
+        errdefer for (cqes[i..]) |next_cqe| { // including the current cqe
+            const next_entry: Entry = @bitCast(next_cqe.user_data);
+            next_entry.deinit(server.allocator);
+        };
+
+        const entry: Entry = @bitCast(cqe.user_data);
+        const entry_data: *EntryData = entry.ptr orelse {
+            // multishot accept cqe
+
+            if (connection.handleAcceptResult(cqe.err())) |accept_result| switch (accept_result) {
+                .success => {},
+                .intr => std.debug.panic("TODO: does this mean the multishot accept has stopped? If no, just warn. If yes, re-queue here and warn.", .{}), // TODO:
+                .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+                // the abort only concerns this one connection; move on to the rest of the batch
+                .conn_aborted => continue :cqe_loop,
+            } else |err| {
+                first_err = first_err orelse err;
+                continue :cqe_loop;
+            }
+
+            const stream: std.net.Stream = .{ .handle = cqe.res };
+            errdefer stream.close();
+
+            server.wait_group.start();
+            errdefer server.wait_group.finish();
+
+            const buffer = try server.allocator.alloc(u8, server.read_buffer_size);
+            errdefer server.allocator.free(buffer);
+
+            const new_recv_entry: Entry = entry: {
+                const data_ptr = try server.allocator.create(EntryData);
+                errdefer comptime unreachable;
+
+                data_ptr.* = EntryData.init(buffer, stream);
+                break :entry .{ .ptr = data_ptr };
+            };
+            errdefer if (new_recv_entry.ptr) |data_ptr| server.allocator.destroy(data_ptr);
+
+            _ = try linux.io_uring.recv(
+                @bitCast(new_recv_entry),
+                stream.handle,
+                .{ .buffer = buffer },
+                0,
+            );
+
+            continue :cqe_loop;
+        };
+
+        switch (entry_data.state) {
+            .recv => |*recv_data| {
+                if (connection.handleRecvResult(cqe.err())) |recv_result| switch (recv_result) {
+                    .success => {},
+
+                    .intr => std.debug.panic("TODO: how to handle interrupts on this?", .{}), // TODO:
+                    .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+
+                    .conn_refused,
+                    .conn_reset,
+                    .timed_out,
+                    => |tag| {
+                        if (connection.getSockName(entry_data.stream.handle)) |addr|
+                            server.logger.warn().logf("{s} ({})", .{ @tagName(tag), addr })
+                        else |_|
+                            server.logger.warn().logf("{s} (unnamed connection?)", .{@tagName(tag)});
+                        entry.deinit(server.allocator);
+                        continue :cqe_loop;
+                    },
+                } else |err| {
+                    server.logger.err().logf("{s}", .{@errorName(err)});
+                    first_err = first_err orelse err;
+                    entry.deinit(server.allocator);
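+                    // drop only this connection; `first_err` is surfaced once the batch is drained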
+                    continue :cqe_loop;
+                }
+
+                const recv_len: usize = @intCast(cqe.res);
+                const body = switch (recv_data.*) {
+                    .head => |*head| body: {
+                        std.debug.assert(head.parser.state != .finished);
+
+                        const recv_start = head.end;
+                        const recv_end = recv_start + recv_len;
+                        head.end += head.parser.feed(entry_data.buffer[recv_start..recv_end]);
+
+                        if (head.parser.state != .finished) {
+                            std.debug.assert(head.end == recv_end);
+
+                            if (head.end == entry_data.buffer.len) {
+                                std.debug.panic("TODO: handle a too-big head", .{}); // TODO:
+                            }
+
+                            _ = try linux.io_uring.recv(
+                                @bitCast(entry),
+                                entry_data.stream.handle,
+                                .{ .buffer = entry_data.buffer[head.end..] },
+                                0,
+                            );
+                            continue :cqe_loop;
+                        }
+
+                        const method: std.http.Method, //
+                        const target: std.BoundedArray(u8, requests.MAX_TARGET_LEN), //
+                        const content_len: ?usize //
+                        = blk: {
+                            const head_bytes = entry_data.buffer[0..head.end];
+                            const parsed_head = std.http.Server.Request.Head.parse(head_bytes) catch |err| {
+                                server.logger.err().logf("{s}", .{@errorName(err)});
+                                first_err = first_err orelse err;
+                                entry.deinit(server.allocator);
+                                continue :cqe_loop;
+                            };
+
+                            // named `target_buf` to avoid shadowing the outer `target`
+                            var target_buf: std.BoundedArray(u8, requests.MAX_TARGET_LEN) = .{};
+                            target_buf.appendSlice(parsed_head.target) catch {
+                                if (connection.getSockName(entry_data.stream.handle)) |addr|
+                                    server.logger.err().logf("{} requested a target '{s}', too long", .{ addr, parsed_head.target })
+                                else |_|
+                                    server.logger.err().logf("Unnamed connection requested a target '{s}', too long", .{parsed_head.target});
+                                entry.deinit(server.allocator);
+                                continue :cqe_loop;
+                            };
+
+                            if (parsed_head.transfer_encoding != .none) std.debug.panic("TODO: handle", .{}); // TODO:
+                            if (parsed_head.transfer_compression != .identity) std.debug.panic("TODO: handle", .{}); // TODO:
+
+                            break :blk .{ parsed_head.method, target_buf, parsed_head.content_length };
+                        };
+
+                        const content_end = blk: {
+                            const old_content_bytes = entry_data.buffer[head.end..recv_end];
+                            std.mem.copyForwards(
+                                u8,
+                                entry_data.buffer[0..old_content_bytes.len],
+                                old_content_bytes,
+                            );
+                            break :blk old_content_bytes.len;
+                        };
+
+                        if (content_len) |len| {
+                            if (len < content_end) {
+                                server.logger.err().logf(
+                                    "HTTP request body ({}) longer than declared content_length ({})",
+                                    .{
+                                        std.fmt.fmtIntSizeDec(content_end),
+                                        std.fmt.fmtIntSizeDec(len),
+                                    },
+                                );
+                                first_err = first_err orelse error.RequestBodyTooLong;
+                                entry.deinit(server.allocator);
+                                continue :cqe_loop;
+                            }
+
+                            if (len > entry_data.buffer.len) {
+                                std.debug.assert(len >= content_end);
+                                if (server.allocator.resize(entry_data.buffer, len)) {
+                                    entry_data.buffer.len = len;
+                                } else {
+                                    const new_mem = try server.allocator.alloc(u8, len);
+                                    // preserve the already-received body bytes before freeing the old buffer
+                                    @memcpy(new_mem[0..content_end], entry_data.buffer[0..content_end]);
+                                    server.allocator.free(entry_data.buffer);
+                                    entry_data.buffer = new_mem;
+                                }
+                            }
+                        }
+
+                        recv_data.* = .{ .body = .{
+                            .head_method = method,
+                            .head_target = target,
+                            .content_len = content_len,
+                            .end = content_end,
+                        } };
+
+                        if (content_len) |len| {
+                            if (len == recv_data.body.end) break :body &recv_data.body;
+                            _ = try linux.io_uring.recv(
+                                @bitCast(entry),
+                                entry_data.stream.handle,
+                                .{ .buffer = entry_data.buffer[content_end..] },
+                                0,
+                            );
+                            continue :cqe_loop;
+                        } else {
+                            if (recv_data.body.end != 0) server.logger.warn().logf( //
+                                "HTTP request sent unexpected body without content_length." ++
+                                    " Ignoring." //
+                            , .{});
+                            break :body &recv_data.body;
+                        }
+                    },
+                    .body => |*body_data| body: {
+                        body_data.end += recv_len;
+                        break :body body_data;
+                    },
+                };
+
+                if (body.content_len) |len| {
+                    if (body.end < len) {
+                        _ = try linux.io_uring.recv(
+                            @bitCast(entry),
+                            entry_data.stream.handle,
+                            .{ .buffer = entry_data.buffer[body.end..len] },
+                            0,
+                        );
+                        continue :cqe_loop;
+                    }
+                }
+
+                const content_bytes: []const u8 = entry_data.buffer[0..body.end];
+                switch (body.head_method) {
+                    .GET => {
+                        if (content_bytes.len != 0) {
+                            if (connection.getSockName(entry_data.stream.handle)) |addr| {
+                                server.logger.warn().logf(
+                                    "{} sent a GET request with" ++
+                                        " a non-empty body ({}).",
+                                    .{ addr, std.fmt.fmtIntSizeDec(content_bytes.len) },
+                                );
+                            } else |_| {
+                                server.logger.warn().logf(
+                                    "Unnamed connection sent a GET request with" ++
+                                        " a non-empty body ({}).",
+                                    .{std.fmt.fmtIntSizeDec(content_bytes.len)},
+                                );
+                            }
+                        }
+
+                        switch (requests.getRequestTargetResolve(
+                            server.logger,
+                            body.head_target.constSlice(),
+                            server.latest_snapshot_gen_info,
+                        )) {
+                            inline .full_snapshot, .inc_snapshot => |pair| {
+                                const snap_info, var full_info_lg = pair;
+                                errdefer full_info_lg.unlock();
+
+                                const archive_name_bounded = snap_info.snapshotArchiveName();
+                                const archive_name = archive_name_bounded.constSlice();
+
+                                const snapshot_dir = server.snapshot_dir;
+                                const archive_file = try snapshot_dir.openFile(archive_name, .{});
+                                errdefer archive_file.close();
+                                const file_size = try archive_file.getEndPos();
+
+                                const pipe_r, const pipe_w = try std.posix.pipe();
+                                errdefer std.posix.close(pipe_w);
+                                errdefer std.posix.close(pipe_r);
+
+                                entry_data.state = .{ .send = .{
+                                    .file_lg = full_info_lg,
+                                    .file = archive_file,
+                                    .file_size = file_size,
+
+                                    .pipe_w = pipe_w,
+                                    .pipe_r = pipe_r,
+
+                                    .spliced_to_pipe = 0,
+                                    .spliced_to_socket = 0,
+                                    .which = .to_pipe,
+                                } };
+                                const send_data = &entry_data.state.send;
+                                try send_data.prepSpliceFileToSocket(entry, &linux.io_uring);
+
+                                continue :cqe_loop;
+                            },
+                            .unrecognized => {},
+                        }
+                    },
+
+                    .POST => {},
+                    else => {},
+                }
+
+                @panic("TODO: handle unhandled"); // TODO:
+            },
+            .send => |*send_data| switch (send_data.which) {
+                .to_pipe => {
+                    if (connection.handleSpliceResult(cqe.err())) |splice_result| switch (splice_result) {
+                        .success => {},
+                        .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+                    } else |err| {
+                        server.logger.err().logf("{s}", .{@errorName(err)});
+                        first_err = first_err orelse err;
+                        entry.deinit(server.allocator);
+                        continue :cqe_loop;
+                    }
+
+                    send_data.spliced_to_pipe += @intCast(cqe.res);
+                    send_data.which = .to_socket;
+                },
+                .to_socket => {
+                    if (connection.handleSpliceResult(cqe.err())) |splice_result| switch (splice_result) {
+                        .success => {},
+                        .again => std.debug.panic("The socket should not be in nonblocking mode.", .{}),
+                    } else |err| {
+                        server.logger.err().logf("{s}", .{@errorName(err)});
+                        first_err = first_err orelse err;
+                        entry.deinit(server.allocator);
+                        continue :cqe_loop;
+                    }
+
+                    send_data.spliced_to_socket += @intCast(cqe.res);
+                    send_data.which = .to_pipe;
+
+                    if (send_data.spliced_to_socket == send_data.file_size) {
+                        std.debug.assert(send_data.spliced_to_socket == send_data.spliced_to_pipe);
+                        entry.deinit(server.allocator);
+                    } else {
+                        try send_data.prepSpliceFileToSocket(entry, &linux.io_uring);
+                    }
+
+                    continue :cqe_loop;
+                },
+            },
+        }
+    }
+
+    return first_err orelse {};
+}
+
+pub const Entry = packed struct(u64) {
+    /// If null, this is an `accept` entry.
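+    /// Otherwise this points at the connection's `EntryData`; the `packed struct(u64)`
+    /// layout is what lets an `Entry` round-trip through the io_uring `user_data` field.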
+    ptr: ?*EntryData,
+
+    pub const ACCEPT: Entry = .{ .ptr = null };
+
+    pub fn deinit(entry: Entry, allocator: std.mem.Allocator) void {
+        const ptr = entry.ptr orelse return;
+        ptr.deinit(allocator);
+        allocator.destroy(ptr);
+    }
+};
+
+pub const EntryData = struct {
+    buffer: []u8,
+    stream: std.net.Stream,
+    state: State,
+
+    fn init(buffer: []u8, stream: std.net.Stream) EntryData {
+        return .{
+            .buffer = buffer,
+            .stream = stream,
+            .state = State.INIT,
+        };
+    }
+
+    fn deinit(data: *EntryData, allocator: std.mem.Allocator) void {
+        data.state.deinit();
+        allocator.free(data.buffer);
+        data.stream.close();
+    }
+
+    pub const State = union(enum) {
+        recv: Recv,
+        send: Send,
+
+        pub const INIT: State = .{
+            .recv = .{
+                .head = .{
+                    .end = 0,
+                    .parser = .{},
+                },
+            },
+        };
+
+        pub fn deinit(state: *State) void {
+            switch (state.*) {
+                .recv => {},
+                .send => |*send_data| send_data.deinit(),
+            }
+        }
+
+        pub const Recv = union(enum) {
+            head: Head,
+            body: Body,
+
+            pub const Head = struct {
+                end: usize,
+                parser: std.http.HeadParser,
+            };
+
+            pub const Body = struct {
+                head_method: std.http.Method,
+                head_target: std.BoundedArray(u8, requests.MAX_TARGET_LEN),
+                content_len: ?usize,
+                end: usize,
+            };
+        };
+
+        pub const Send = struct {
+            file_lg: requests.GetRequestTargetResolved.SnapshotReadLock,
+            file: std.fs.File,
+            file_size: u64,
+
+            pipe_w: std.os.linux.fd_t,
+            pipe_r: std.os.linux.fd_t,
+
+            spliced_to_pipe: u64,
+            spliced_to_socket: u64,
+            which: Which,
+
+            pub const Which = enum {
+                to_pipe,
+                to_socket,
+            };
+
+            pub fn deinit(self: *Send) void {
+                self.file.close();
+                self.file_lg.unlock();
+                std.posix.close(self.pipe_w);
+                std.posix.close(self.pipe_r);
+            }
+
+            fn prepSpliceFileToSocket(self: *const Send, entry: Entry, io_uring: *IoUring) !void {
+                std.debug.assert(self == &entry.ptr.?.state.send);
+                const stream = entry.ptr.?.stream;
+
+                // file -> pipe, linked so it completes before the pipe -> socket splice
+                const splice1_sqe = try io_uring.splice(
+                    @bitCast(entry),
+                    self.file.handle,
+                    self.spliced_to_pipe,
+                    self.pipe_w,
+                    std.math.maxInt(u64),
+                    self.file_size - self.spliced_to_pipe,
+                );
+                splice1_sqe.flags |= std.os.linux.IOSQE_IO_LINK;
+
+                // pipe -> socket; both sqes carry this entry as their user_data
+                const splice2_sqe = try io_uring.splice(
+                    @bitCast(entry),
+                    self.pipe_r,
+                    std.math.maxInt(u64),
+                    stream.handle,
+                    std.math.maxInt(u64),
+                    self.file_size - self.spliced_to_socket,
+                );
+                _ = splice2_sqe;
+            }
+        };
+    };
+};
diff --git a/src/rpc/server/connection.zig b/src/rpc/server/connection.zig
new file mode 100644
index 000000000..e7f0d854e
--- /dev/null
+++ b/src/rpc/server/connection.zig
@@ -0,0 +1,154 @@
+const builtin = @import("builtin");
+const std = @import("std");
+
+pub fn getSockName(
+    socket_handle: std.posix.socket_t,
+) std.posix.GetSockNameError!std.net.Address {
+    var addr: std.net.Address = .{ .any = undefined };
+    var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(addr.any));
+    try std.posix.getsockname(socket_handle, &addr.any, &addr_len);
+    return addr;
+}
+
+pub const WithLazyAddr = struct {
+    stream: std.net.Stream,
+    address: ?std.net.Address,
+
+    pub fn toStdConnection(self: WithLazyAddr) std.posix.GetSockNameError!std.net.Server.Connection {
+        return .{
+            .stream = self.stream,
+            .address = try self.getAddress(),
+        };
+    }
+
+    pub fn getAddress(self: WithLazyAddr) std.posix.GetSockNameError!std.net.Address {
+        return self.address orelse try getSockName(self.stream.handle);
+    }
+
+    pub fn getAndCacheAddress(self: *WithLazyAddr) std.posix.GetSockNameError!std.net.Address {
+        const address = try self.getAddress();
+        self.address = address;
+        return address;
+    }
+};
+
+pub const HandleAcceptError = error{
+    ProcessFdQuotaExceeded,
+    SystemFdQuotaExceeded,
+    SystemResources,
+    ProtocolFailure,
+    BlockedByFirewall,
+} || std.posix.UnexpectedError;
+
+pub const HandleAcceptResult = enum {
+    success,
+    intr,
+    again,
+    conn_aborted,
+};
+
+/// Resembles the error handling of `std.posix.accept`.
+pub fn handleAcceptResult(
+    /// Must be the result of `std.posix.accept` or equivalent (i.e. io_uring cqe.err()).
+    rc: std.posix.E,
+) HandleAcceptError!HandleAcceptResult {
+    comptime std.debug.assert( //
+        builtin.target.isDarwin() or builtin.target.os.tag == .linux //
+    );
+    return switch (rc) {
+        .SUCCESS => .success,
+        .INTR => .intr,
+        .AGAIN => .again,
+        .CONNABORTED => .conn_aborted,
+
+        .BADF, // always a race condition
+        .FAULT, // don't address bad memory
+        .NOTSOCK, // don't call accept on a non-socket
+        .OPNOTSUPP, // socket must support accept
+        .INVAL, // socket must be listening
+        => |e| std.debug.panic("{s}", .{@tagName(e)}),
+
+        .MFILE => return error.ProcessFdQuotaExceeded,
+        .NFILE => return error.SystemFdQuotaExceeded,
+        .NOBUFS => return error.SystemResources,
+        .NOMEM => return error.SystemResources,
+        .PROTO => return error.ProtocolFailure,
+        .PERM => return error.BlockedByFirewall,
+        else => |err| return std.posix.unexpectedErrno(err),
+    };
+}
+
+pub const HandleRecvError = error{
+    SystemResources,
+} || std.posix.UnexpectedError;
+
+pub const HandleRecvResult = enum {
+    success,
+    intr,
+    again,
+    conn_refused,
+    conn_reset,
+    timed_out,
+};
+
+/// Resembles the error handling of `std.posix.recv`.
+pub fn handleRecvResult(
+    /// Must be the result of `std.posix.recv` or equivalent (i.e. io_uring cqe.err()).
+    rc: std.posix.E,
+) HandleRecvError!HandleRecvResult {
+    comptime std.debug.assert( //
+        builtin.target.isDarwin() or builtin.target.os.tag == .linux //
+    );
+    return switch (rc) {
+        .SUCCESS => .success,
+        .INTR => .intr,
+        .AGAIN => .again,
+        .CONNREFUSED => .conn_refused,
+        .CONNRESET => .conn_reset,
+        .TIMEDOUT => .timed_out,
+
+        .BADF, // always a race condition
+        .FAULT, // don't address bad memory
+        .INVAL, // invalid argument
+        .NOTSOCK, // don't call recv on a non-socket
+        .NOTCONN, // we should always be connected
+        => |e| std.debug.panic("{s}", .{@tagName(e)}),
+
+        .NOMEM => return error.SystemResources,
+        else => |err| return std.posix.unexpectedErrno(err),
+    };
+}
+
+pub const HandleSpliceError = error{
+    /// One or both file descriptors are not valid, or do not have proper read-write mode.
+    BadFileDescriptors,
+    /// Either off_in or off_out was not NULL, but the corresponding file descriptor refers to a pipe.
+    BadFdOffset,
+    /// Could be one of many reasons, see the manpage for splice.
+    InvalidSplice,
+    /// Out of memory.
+    SystemResources,
+};
+
+pub const HandleSpliceResult = enum {
+    success,
+    again,
+};
+
+pub fn handleSpliceResult(
+    /// Must be the result of calling the `splice` syscall or equivalent (i.e. io_uring cqe.err()).
+    /// Note that `.AGAIN` is a result rather than an error: the caller may simply re-queue the splice.
+ rc: std.posix.E, +) HandleSpliceError!HandleSpliceResult { + comptime std.debug.assert( // + builtin.target.os.tag == .linux // + ); + return switch (rc) { + .SUCCESS => .success, + .AGAIN => .again, + .INVAL => return error.InvalidSplice, + .SPIPE => return error.BadFdOffset, + .BADF => return error.BadFileDescriptors, + .NOMEM => return error.SystemResources, + else => |err| std.posix.unexpectedErrno(err), + }; +} diff --git a/src/rpc/server/requests.zig b/src/rpc/server/requests.zig new file mode 100644 index 000000000..5135fb67b --- /dev/null +++ b/src/rpc/server/requests.zig @@ -0,0 +1,203 @@ +const builtin = @import("builtin"); +const std = @import("std"); +const sig = @import("../../sig.zig"); +const connection = @import("connection.zig"); + +const IoUring = std.os.linux.IoUring; + +const Server = sig.rpc.Server; +const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; +const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo; +const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo; + +pub const MAX_TARGET_LEN: usize = blk: { + const SnapSpec = IncrementalSnapshotFileInfo.SnapshotArchiveNameFmtSpec; + break :blk "/".len + SnapSpec.fmtLenValue(.{ + .base_slot = std.math.maxInt(sig.core.Slot), + .slot = std.math.maxInt(sig.core.Slot), + .hash = sig.core.Hash.base58String(.{ .data = .{255} ** sig.core.Hash.size }).constSlice(), + }); +}; + +pub const GetRequestTargetResolved = union(enum) { + unrecognized, + full_snapshot: struct { FullSnapshotFileInfo, SnapshotReadLock }, + inc_snapshot: struct { IncrementalSnapshotFileInfo, SnapshotReadLock }, + + // TODO: also handle the snapshot archive aliases & other routes + + pub const SnapshotReadLock = sig.sync.RwMux(?SnapshotGenerationInfo).RLockGuard; +}; + +/// Resolve a `GET` request target. +pub fn getRequestTargetResolve( + logger: Server.ScopedLogger, + target: []const u8, + latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), +) GetRequestTargetResolved { + if (!std.mem.startsWith(u8, target, "/")) return .unrecognized; + const path = target[1..]; + + const is_snapshot_archive_like = + !std.meta.isError(FullSnapshotFileInfo.parseFileNameTarZst(path)) or + !std.meta.isError(IncrementalSnapshotFileInfo.parseFileNameTarZst(path)); + + if (is_snapshot_archive_like) { + // we hold the lock for the entirety of this process in order to prevent + // the snapshot generation process from deleting the associated snapshot. 
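+        // On a match below, this lock is handed back to the caller inside `SnapshotReadLock`,
+        // making the caller responsible for unlocking it once the response has been sent.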
+ const maybe_latest_snapshot_gen_info, // + var latest_snapshot_info_lg // + = latest_snapshot_gen_info_rw.readWithLock(); + errdefer latest_snapshot_info_lg.unlock(); + + const full_info: ?FullSnapshotFileInfo, // + const inc_info: ?IncrementalSnapshotFileInfo // + = blk: { + const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse + break :blk .{ null, null }; + const latest_full = latest_snapshot_gen_info.full; + const full_info: FullSnapshotFileInfo = .{ + .slot = latest_full.slot, + .hash = latest_full.hash, + }; + const latest_incremental = latest_snapshot_gen_info.inc orelse + break :blk .{ full_info, null }; + const inc_info: IncrementalSnapshotFileInfo = .{ + .base_slot = latest_full.slot, + .slot = latest_incremental.slot, + .hash = latest_incremental.hash, + }; + break :blk .{ full_info, inc_info }; + }; + + logger.debug().logf("Available full: {?s}", .{ + if (full_info) |info| info.snapshotArchiveName().constSlice() else null, + }); + logger.debug().logf("Available inc: {?s}", .{ + if (inc_info) |info| info.snapshotArchiveName().constSlice() else null, + }); + + if (full_info) |full| { + const full_archive_name_bounded = full.snapshotArchiveName(); + const full_archive_name = full_archive_name_bounded.constSlice(); + if (std.mem.eql(u8, path, full_archive_name)) { + return .{ .full_snapshot = .{ full, latest_snapshot_info_lg } }; + } + } + + if (inc_info) |inc| { + const inc_archive_name_bounded = inc.snapshotArchiveName(); + const inc_archive_name = inc_archive_name_bounded.constSlice(); + if (std.mem.eql(u8, path, inc_archive_name)) { + return .{ .inc_snapshot = .{ inc, latest_snapshot_info_lg } }; + } + } + } + + return .unrecognized; +} + +pub const HandleRequestError = + std.fs.File.OpenError || + HttpResponseSendFileError; + +pub fn handleRequest( + logger: Server.ScopedLogger, + request: *std.http.Server.Request, + snapshot_dir: std.fs.Dir, + latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), +) !void { + const conn_address = request.server.connection.address; + logger.info().logf("Responding to request from {}: {} {s}", .{ + conn_address, methodFmt(request.head.method), request.head.target, + }); + + switch (request.head.method) { + .GET => switch (getRequestTargetResolve( + logger, + request.head.target, + latest_snapshot_gen_info_rw, + )) { + .unrecognized => {}, + inline .full_snapshot, .inc_snapshot => |pair| { + const snap_info, var full_info_lg = pair; + defer full_info_lg.unlock(); + + const archive_name_bounded = snap_info.snapshotArchiveName(); + const archive_name = archive_name_bounded.constSlice(); + + const archive_file = try snapshot_dir.openFile(archive_name, .{}); + defer archive_file.close(); + + var send_buffer: [4096]u8 = undefined; + try httpResponseSendFile(request, archive_file, &send_buffer); + return; + }, + }, + .POST => { + logger.err().logf("{} tried to invoke our RPC", .{conn_address}); + return try request.respond("RPCs are not yet implemented", .{ + .status = .service_unavailable, + .keep_alive = false, + }); + }, + else => {}, + } + + logger.err().logf( + "{} made an unrecognized request '{} {s}'", + .{ conn_address, methodFmt(request.head.method), request.head.target }, + ); + try request.respond("", .{ + .status = .not_found, + .keep_alive = false, + }); +} + +const HttpResponseSendFileError = + std.fs.File.GetSeekPosError || + std.fs.File.SeekError || + std.http.Server.Response.WriteError || + std.fs.File.ReadError; + +fn httpResponseSendFile( + request: *std.http.Server.Request, + archive_file: 
std.fs.File, + send_buffer: []u8, +) HttpResponseSendFileError!void { + const archive_len = try archive_file.getEndPos(); + + var response = request.respondStreaming(.{ + .send_buffer = send_buffer, + .content_length = archive_len, + }); + const writer = sig.utils.io.narrowAnyWriter( + response.writer(), + std.http.Server.Response.WriteError, + ); + + const Fifo = std.fifo.LinearFifo(u8, .{ .Static = 1 }); + var fifo: Fifo = Fifo.init(); + try archive_file.seekTo(0); + try fifo.pump(archive_file.reader(), writer); + + try response.end(); +} + +fn methodFmt(method: std.http.Method) MethodFmt { + return .{ .method = method }; +} + +const MethodFmt = struct { + method: std.http.Method, + pub fn format( + fmt: MethodFmt, + comptime fmt_str: []const u8, + fmt_options: std.fmt.FormatOptions, + writer: anytype, + ) @TypeOf(writer).Error!void { + _ = fmt_options; + if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt); + try fmt.method.write(writer); + } +}; diff --git a/src/utils/fmt.zig b/src/utils/fmt.zig index bfd739a57..e8a3f519b 100644 --- a/src/utils/fmt.zig +++ b/src/utils/fmt.zig @@ -36,7 +36,7 @@ pub fn BoundedSpec(comptime spec: []const u8) type { /// try expectEqual("foo-255".len, boundedLenValue("{[a]s}-{[b]d}", .{ .a = "foo", .b = 255 })); /// ``` pub inline fn fmtLenValue(comptime args_value: anytype) usize { - comptime return fmtLen(fmt_str, @TypeOf(args_value)); + comptime return fmtLen(@TypeOf(args_value)); } pub fn BoundedArray(comptime Args: type) type {
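Reviewer note: a minimal sketch of driving the reworked server, assuming an existing
`allocator`, `logger`, `snapshot_dir`, `latest_snapshot_gen_info`, and an `exit` flag
(these names are illustrative stand-ins, not part of this patch):

```zig
var server = try Server.init(.{
    .allocator = allocator,
    .logger = logger,
    .snapshot_dir = snapshot_dir,
    .latest_snapshot_gen_info = &latest_snapshot_gen_info,
    .socket_addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, 8899),
    .read_buffer_size = 4096,
    // set to true to skip io_uring even where it is available
    .force_basic_work_pool = false,
});
defer server.joinDeinit();

while (!exit.load(.acquire)) {
    // On Linux >= 6.0 this drains a batch of io_uring completions;
    // elsewhere it accepts and serves a single blocking connection.
    try server.acceptAndServeConnection(.{});
}
```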