From 694aa99613acca94e057e41bc839639f93103dea Mon Sep 17 00:00:00 2001 From: InKryption <59504965+InKryption@users.noreply.github.com> Date: Mon, 10 Feb 2025 20:42:37 +0100 Subject: [PATCH] perf(rpc): io_uring integration & redesign (#477) * RPC server: io_uring upgrade Separates the server into two parts: the context, and the work pool; the context contains everything generally needed to run the server, the work pool contains a statically polymorphic implementation for a pool to dispatch the actual work to. In doing this, we also separate certain things out into a few different files. The RPC server context API has been modified slightly to reflect this, and the work pool directly exposed, for now. * Don't use file-as-struct * Run style script, respect line length limit * Improve accept failure handling & update TODOs * Handle potentially failing/cancelling of `accept_multishot` by re-queueing it, based on the `IORING_CQE_F_MORE` flag. * Revise/simplify the queueing logic for the `accept_multishot` SQE. * Resolve the EINTR TODO panics, returning a catch-all error value indicating it as a bad but non-critical error. * Update the `a: ?noreturn` `if (a) |*b|` TODO, adding that it's solved in 0.14; it should be resolved after we update to 0.14. * Unify EAGAIN panic message. * Add TODO to remove hacky-ish workaround * Use `self: Type` convention * A few minor fixups and improvements * Simplify test, make server socket nonblocking On MacOS, on basic WorkPool, this means we now need to manually set the accepted socket's flags to the right things, ie, blocking, as opposed to the server socket's nonblocking mode. Means we also have to handle EAGAIN a bit differently in the io_uring backend, but that's a fine tradeoff. * server.zig -> server/lib.zig * Segregate out the basic backend And re-organize some methods based on that change * Re-organize server module * Update LOGGER_SCOPE * Simplify & improve io_uring backend error handling * De-scope `accept_flags` * Simplify `can_use` for linux cross-compilation * (io_uring) Rework error handling, add timeout Do not exit for *any* errors that are specific to the related connection, simply free them and continue to the next CQE. Specifically in the case of `error.SubmissionQueueFull`, instead of immediately failing, we instead first try to flush the submission queue and then try again to submit; if it fails a second time, that means despite flushing the submission queue, it somehow still failed, so we panic, since this indicates something is *very* wrong. This also eliminates the `pending_cqes_buf`, since there is actually no situation in which `consumeOurCqe` returns an error, and we resume work afterwards - either we process all the received CQEs, or we hard exit - this was already essentially the case before, now it's more obvious. For the main submit, we now wait for at least 1 connection, but we also add a timeout SQE to make it terminate if we don't receive a connection or completion of another task for 1 second; this alleviates the busy loop that was running before. * (io_uring) Remove multishot_accept_submitted Also slightly refactor error sets. Now instead of checking to see if we need to set a flag to re-queue the multishot accept, we just pass in the server context on init and queue it, which now makes sense since the context and workpool are separate. * (io_uring) Simplify new entry creation Also add fix for rebase * Misc fixups * Re-organize alias/import * General restructure * Move more specific functions to the only files they're used. * Move the `serve*` functions outside of `Context`, making them free functions which just accept the context and work pool. * Remove `acceptAndServeConnection`; originally this was required to be able to nicely structure the unit test, and used to be more integrated, however it no longer makes sense as a concept. * Inline `handleRequest` into the basic backend. * Make the `acceptHandled` function, moved into the basic backend, guarantee the specified `sync` behavior, and inline `have_accept4`. * Appropriately re-export the relevant parts of the server API. * Added top level doc comments. * Re-oorganize loggers & scopes * Refactor `build_options` imports * Add `no_network_tests` build option And disable the rpc server test when it is enabled * Update circleci with `-Dno-network-tests` --- .circleci/config.yml | 5 +- build.zig | 2 + src/accountsdb/db.zig | 1 + src/accountsdb/snapshots.zig | 11 +- src/ledger/blockstore.zig | 4 +- src/ledger/cleanup_service.zig | 3 +- src/ledger/database/hashmap.zig | 1 - src/ledger/database/rocksdb.zig | 3 +- src/ledger/fuzz.zig | 5 +- src/rpc/lib.zig | 3 +- src/rpc/server.zig | 559 -------------------- src/rpc/server/basic.zig | 194 +++++++ src/rpc/server/connection.zig | 180 +++++++ src/rpc/server/lib.zig | 25 + src/rpc/server/linux_io_uring.zig | 852 ++++++++++++++++++++++++++++++ src/rpc/server/requests.zig | 182 +++++++ src/rpc/server/server.zig | 295 +++++++++++ src/sig.zig | 1 + src/utils/fmt.zig | 2 +- src/utils/io.zig | 84 +++ 20 files changed, 1833 insertions(+), 579 deletions(-) delete mode 100644 src/rpc/server.zig create mode 100644 src/rpc/server/basic.zig create mode 100644 src/rpc/server/connection.zig create mode 100644 src/rpc/server/lib.zig create mode 100644 src/rpc/server/linux_io_uring.zig create mode 100644 src/rpc/server/requests.zig create mode 100644 src/rpc/server/server.zig diff --git a/.circleci/config.yml b/.circleci/config.yml index 04da0a572..2edd188e1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -153,7 +153,8 @@ jobs: key: linux-x86_64-0.13.0-{{ checksum "build.zig.zon" }}-v6 - run: name: Test - command: workspace/zig/zig build test -Dcpu=x86_64_v3 -Denable-tsan=true --color off --summary all + # Disable network-accessing tests for this job, which behave badly on circleci + command: workspace/zig/zig build test -Dcpu=x86_64_v3 -Denable-tsan=true -Dno-network-tests --color off --summary all test_kcov_linux: executor: linux-executor @@ -165,7 +166,7 @@ jobs: key: linux-x86_64-0.13.0-{{ checksum "build.zig.zon" }}-v6 - run: name: Build - command: workspace/zig/zig build test -Dcpu=x86_64_v3 -Denable-tsan=false -Dno-run --summary all + command: workspace/zig/zig build test -Dcpu=x86_64_v3 -Denable-tsan=false -Dno-run -Dno-network-tests --summary all - run: name: Test and Collect command: | diff --git a/build.zig b/build.zig index add14cd92..c9e886b52 100644 --- a/build.zig +++ b/build.zig @@ -20,10 +20,12 @@ pub fn build(b: *Build) void { \\Don't install any of the binaries implied by the specified steps, only run them. \\Use in conjunction with 'no-run' to avoid running as well. ) orelse false; + const no_network_tests = b.option(bool, "no-network-tests", "Do not run any tests that depend on the network.") orelse false; // Build options const build_options = b.addOptions(); build_options.addOption(BlockstoreDB, "blockstore_db", blockstore_db); + build_options.addOption(bool, "no_network_tests", no_network_tests); // CLI build steps const install_step = b.getInstallStep(); diff --git a/src/accountsdb/db.zig b/src/accountsdb/db.zig index 55ad0f827..9f27d0d91 100644 --- a/src/accountsdb/db.zig +++ b/src/accountsdb/db.zig @@ -3350,6 +3350,7 @@ test "testWriteSnapshot" { ); } +/// Unpacks the snapshots from `sig.TEST_DATA_DIR`. pub fn findAndUnpackTestSnapshots( n_threads: usize, /// The directory into which the snapshots are unpacked. diff --git a/src/accountsdb/snapshots.zig b/src/accountsdb/snapshots.zig index 1c8657903..125869872 100644 --- a/src/accountsdb/snapshots.zig +++ b/src/accountsdb/snapshots.zig @@ -2215,7 +2215,7 @@ pub const FullSnapshotFileInfo = struct { slot: Slot, hash: Hash, - const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("snapshot-{[slot]d}-{[hash]s}.tar.zst"); + pub const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("snapshot-{[slot]d}-{[hash]s}.tar.zst"); pub const SnapshotArchiveNameStr = SnapshotArchiveNameFmtSpec.BoundedArrayValue(.{ .slot = std.math.maxInt(Slot), @@ -2341,7 +2341,7 @@ pub const IncrementalSnapshotFileInfo = struct { }; } - const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("incremental-snapshot-{[base_slot]d}-{[slot]d}-{[hash]s}.tar.zst"); + pub const SnapshotArchiveNameFmtSpec = sig.utils.fmt.BoundedSpec("incremental-snapshot-{[base_slot]d}-{[slot]d}-{[hash]s}.tar.zst"); pub const SnapshotArchiveNameStr = SnapshotArchiveNameFmtSpec.BoundedArrayValue(.{ .base_slot = std.math.maxInt(Slot), @@ -2486,15 +2486,16 @@ pub const SnapshotFiles = struct { full: FullSnapshotFileInfo, incremental_info: ?SlotAndHash, - pub fn incremental(snapshot_files: SnapshotFiles) ?IncrementalSnapshotFileInfo { - const inc_info = snapshot_files.incremental_info orelse return null; + pub fn incremental(self: SnapshotFiles) ?IncrementalSnapshotFileInfo { + const inc_info = self.incremental_info orelse return null; return .{ - .base_slot = snapshot_files.full.slot, + .base_slot = self.full.slot, .slot = inc_info.slot, .hash = inc_info.hash, }; } + /// Asserts that `if (maybe_incremental_info) |inc| inc.base_slot == full_info.slot`. pub fn fromFileInfos( full_info: FullSnapshotFileInfo, maybe_incremental_info: ?IncrementalSnapshotFileInfo, diff --git a/src/ledger/blockstore.zig b/src/ledger/blockstore.zig index 1ca184014..01fc27945 100644 --- a/src/ledger/blockstore.zig +++ b/src/ledger/blockstore.zig @@ -1,7 +1,7 @@ -const build_options = @import("build-options"); +const sig = @import("../sig.zig"); const ledger = @import("lib.zig"); -pub const BlockstoreDB = switch (build_options.blockstore_db) { +pub const BlockstoreDB = switch (sig.build_options.blockstore_db) { .rocksdb => ledger.database.RocksDB(&ledger.schema.list), .hashmap => ledger.database.SharedHashMapDB(&ledger.schema.list), }; diff --git a/src/ledger/cleanup_service.zig b/src/ledger/cleanup_service.zig index f7e9563f2..929a552cf 100644 --- a/src/ledger/cleanup_service.zig +++ b/src/ledger/cleanup_service.zig @@ -1,4 +1,3 @@ -const build_options = @import("build-options"); const std = @import("std"); const sig = @import("../sig.zig"); const ledger = @import("lib.zig"); @@ -461,7 +460,7 @@ test "findSlotsToClean" { } // When implementation is rocksdb, we need to flush memtable to disk to be able to assert. // We do that by deiniting the current db, which triggers the flushing. - if (build_options.blockstore_db == .rocksdb) { + if (sig.build_options.blockstore_db == .rocksdb) { db.deinit(); db = try TestDB.reuseBlockstore(@src()); reader.db = db; diff --git a/src/ledger/database/hashmap.zig b/src/ledger/database/hashmap.zig index 8c629a31f..d14e4f3e6 100644 --- a/src/ledger/database/hashmap.zig +++ b/src/ledger/database/hashmap.zig @@ -1,7 +1,6 @@ const std = @import("std"); const sig = @import("../../sig.zig"); const database = @import("lib.zig"); -const build_options = @import("build-options"); const Allocator = std.mem.Allocator; const RwLock = std.Thread.RwLock; diff --git a/src/ledger/database/rocksdb.zig b/src/ledger/database/rocksdb.zig index bf572cfd2..8bfb65a71 100644 --- a/src/ledger/database/rocksdb.zig +++ b/src/ledger/database/rocksdb.zig @@ -2,7 +2,6 @@ const std = @import("std"); const rocks = @import("rocksdb"); const sig = @import("../../sig.zig"); const database = @import("lib.zig"); -const build_options = @import("build-options"); const Allocator = std.mem.Allocator; @@ -348,7 +347,7 @@ fn callRocks(logger: ScopedLogger(LOG_SCOPE), comptime func: anytype, args: anyt } comptime { - if (build_options.blockstore_db == .rocksdb) { + if (sig.build_options.blockstore_db == .rocksdb) { _ = &database.interface.testDatabase(RocksDB); } } diff --git a/src/ledger/fuzz.zig b/src/ledger/fuzz.zig index 1cf7654fd..64c6d4eef 100644 --- a/src/ledger/fuzz.zig +++ b/src/ledger/fuzz.zig @@ -1,6 +1,5 @@ const std = @import("std"); const sig = @import("../sig.zig"); -const build_options = @import("build-options"); const ledger = @import("lib.zig"); const ColumnFamily = sig.ledger.database.ColumnFamily; @@ -19,7 +18,7 @@ const cf1 = ColumnFamily{ var executed_actions = std.AutoHashMap(Actions, void).init(allocator); -pub const BlockstoreDB = switch (build_options.blockstore_db) { +pub const BlockstoreDB = switch (sig.build_options.blockstore_db) { .rocksdb => ledger.database.RocksDB(&.{cf1}), .hashmap => ledger.database.SharedHashMapDB(&.{cf1}), }; @@ -198,7 +197,7 @@ fn dbCount( try executed_actions.put(Actions.count, {}); // TODO Fix why changes are not reflected in count with rocksdb implementation, // but it does with hashmap. - if (build_options.blockstore_db == .rocksdb) { + if (sig.build_options.blockstore_db == .rocksdb) { return; } diff --git a/src/rpc/lib.zig b/src/rpc/lib.zig index 0b5fdb86c..5490bf819 100644 --- a/src/rpc/lib.zig +++ b/src/rpc/lib.zig @@ -1,12 +1,11 @@ pub const client = @import("client.zig"); -pub const server = @import("server.zig"); +pub const server = @import("server/lib.zig"); pub const request = @import("request.zig"); pub const response = @import("response.zig"); pub const types = @import("types.zig"); pub const Client = client.Client; -pub const Server = server.Server; pub const Request = request.Request; pub const Response = response.Response; diff --git a/src/rpc/server.zig b/src/rpc/server.zig deleted file mode 100644 index c320e3900..000000000 --- a/src/rpc/server.zig +++ /dev/null @@ -1,559 +0,0 @@ -const std = @import("std"); -const sig = @import("../sig.zig"); - -const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; -const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo; -const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo; -const ThreadPool = sig.sync.ThreadPool; - -const LOGGER_SCOPE = "rpc.Server"; -const ScopedLogger = sig.trace.log.ScopedLogger(LOGGER_SCOPE); - -pub const Server = struct { - //! Basic usage: - //! ```zig - //! var server = try Server.init(.{...}); - //! defer server.joinDeinit(); - //! - //! try server.serveSpawnDetached(); // or `.serveDirect`, if the caller can block or is managing the separate thread themselves. - //! ``` - - allocator: std.mem.Allocator, - logger: ScopedLogger, - - snapshot_dir: std.fs.Dir, - latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), - - /// Wait group for all currently running tasks, used to wait for - /// all of them to finish before deinitializing. - wait_group: std.Thread.WaitGroup, - thread_pool: *ThreadPool, - - /// Must not be mutated. - read_buffer_size: usize, - tcp: std.net.Server, - - pub const MIN_READ_BUFFER_SIZE = 256; - - /// The returned result must be pinned to a memory location before calling any methods. - pub fn init(params: struct { - /// Must be a thread-safe allocator. - allocator: std.mem.Allocator, - logger: sig.trace.Logger, - - /// Not closed by the `Server`, but must live at least as long as it. - snapshot_dir: std.fs.Dir, - /// Should reflect the latest generated snapshot eligible for propagation at any - /// given time with respect to the contents of the specified `snapshot_dir`. - latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), - - thread_pool: *ThreadPool, - - /// The size for the read buffer allocated to every request. - /// Clamped to be greater than or equal to `MIN_READ_BUFFER_SIZE`. - read_buffer_size: u32, - /// The socket address to listen on for incoming HTTP and/or RPC requests. - socket_addr: std.net.Address, - }) std.net.Address.ListenError!Server { - var tcp_server = try params.socket_addr.listen(.{ - // NOTE: ideally we would be doing this nonblockingly, however this doesn't work properly on mac, - // so for testing purposes we can't test the `serve` functionality directly. - .force_nonblocking = false, - }); - errdefer tcp_server.deinit(); - - return .{ - .allocator = params.allocator, - .logger = params.logger.withScope(LOGGER_SCOPE), - - .snapshot_dir = params.snapshot_dir, - .latest_snapshot_gen_info = params.latest_snapshot_gen_info, - - .wait_group = .{}, - .thread_pool = params.thread_pool, - - .read_buffer_size = @max(params.read_buffer_size, MIN_READ_BUFFER_SIZE), - .tcp = tcp_server, - }; - } - - /// Blocks until all tasks are completed, and then closes the server. - /// Does not force the server to exit. - pub fn joinDeinit(server: *Server) void { - server.wait_group.wait(); - server.tcp.deinit(); - } - - /// Spawn the serve loop as a separate thread. - pub fn serveSpawn( - server: *Server, - exit: *std.atomic.Value(bool), - ) std.Thread.SpawnError!std.Thread { - return std.Thread.spawn(.{}, serve, .{ server, exit }); - } - - /// Calls `acceptAndServeConnection` in a loop until `exit.load(.acquire)`. - pub fn serve( - server: *Server, - exit: *std.atomic.Value(bool), - ) AcceptAndServeConnectionError!void { - while (!exit.load(.acquire)) { - try server.acceptAndServeConnection(); - } - } - - pub const AcceptAndServeConnectionError = - std.mem.Allocator.Error || - std.http.Server.ReceiveHeadError || - AcceptConnectionError; - - pub fn acceptAndServeConnection(server: *Server) AcceptAndServeConnectionError!void { - const conn = (try acceptConnection(&server.tcp, server.logger)).?; - errdefer conn.stream.close(); - - server.wait_group.start(); - errdefer server.wait_group.finish(); - - const new_hct = try HandleConnectionTask.createAndReceiveHead(server, conn); - errdefer new_hct.destroyAndClose(); - - server.thread_pool.schedule(ThreadPool.Batch.from(&new_hct.task)); - } -}; - -const HandleConnectionTask = struct { - task: ThreadPool.Task, - server: *Server, - http_server: std.http.Server, - request: std.http.Server.Request, - - fn createAndReceiveHead( - server: *Server, - conn: std.net.Server.Connection, - ) (std.http.Server.ReceiveHeadError || std.mem.Allocator.Error)!*HandleConnectionTask { - const allocator = server.allocator; - - const hct_buf_align = @alignOf(HandleConnectionTask); - const hct_buf_size = initBufferSize(server.read_buffer_size); - - const hct_buffer = try allocator.alignedAlloc(u8, hct_buf_align, hct_buf_size); - errdefer server.allocator.free(hct_buffer); - - const hct: *HandleConnectionTask = std.mem.bytesAsValue( - HandleConnectionTask, - hct_buffer[0..@sizeOf(HandleConnectionTask)], - ); - hct.* = .{ - .task = .{ .callback = callback }, - .server = server, - .http_server = std.http.Server.init(conn, getReadBuffer(server.read_buffer_size, hct)), - .request = try hct.http_server.receiveHead(), - }; - - return hct; - } - - /// Does not release the connection. - fn destroyAndClose(hct: *HandleConnectionTask) void { - const allocator = hct.server.allocator; - - const full_buffer = getFullBuffer(hct.server.read_buffer_size, hct); - defer allocator.free(full_buffer); - - const connection = hct.http_server.connection; - defer connection.stream.close(); - } - - fn initBufferSize(read_buffer_size: usize) usize { - return @sizeOf(HandleConnectionTask) + read_buffer_size; - } - - fn getFullBuffer( - read_buffer_size: usize, - hct: *HandleConnectionTask, - ) []align(@alignOf(HandleConnectionTask)) u8 { - const ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(hct); - return ptr[0..initBufferSize(read_buffer_size)]; - } - - fn getReadBuffer( - read_buffer_size: usize, - hct: *HandleConnectionTask, - ) []u8 { - return getFullBuffer(read_buffer_size, hct)[@sizeOf(HandleConnectionTask)..]; - } - - fn callback(task: *ThreadPool.Task) void { - const hct: *HandleConnectionTask = @fieldParentPtr("task", task); - defer hct.destroyAndClose(); - - const server = hct.server; - const logger = server.logger; - - const wait_group = &server.wait_group; - defer wait_group.finish(); - - handleRequest( - logger, - &hct.request, - server.snapshot_dir, - server.latest_snapshot_gen_info, - ) catch |err| { - if (@errorReturnTrace()) |stack_trace| { - logger.err().logf("{s}\n{}", .{ @errorName(err), stack_trace }); - } else { - logger.err().logf("{s}", .{@errorName(err)}); - } - }; - } -}; - -fn handleRequest( - logger: ScopedLogger, - request: *std.http.Server.Request, - snapshot_dir: std.fs.Dir, - latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), -) !void { - const conn_address = request.server.connection.address; - - logger.info().logf("Responding to request from {}: {} {s}", .{ - conn_address, methodFmt(request.head.method), request.head.target, - }); - switch (request.head.method) { - .POST => { - logger.err().logf("{} tried to invoke our RPC", .{conn_address}); - return try request.respond("RPCs are not yet implemented", .{ - .status = .service_unavailable, - .keep_alive = false, - }); - }, - .GET => get_blk: { - if (!std.mem.startsWith(u8, request.head.target, "/")) break :get_blk; - const path = request.head.target[1..]; - - // we hold the lock for the entirety of this process in order to prevent - // the snapshot generation process from deleting the associated snapshot. - const maybe_latest_snapshot_gen_info, // - var latest_snapshot_info_lg // - = latest_snapshot_gen_info_rw.readWithLock(); - defer latest_snapshot_info_lg.unlock(); - - const full_info: ?FullSnapshotFileInfo, // - const inc_info: ?IncrementalSnapshotFileInfo // - = blk: { - const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse - break :blk .{ null, null }; - const latest_full = latest_snapshot_gen_info.full; - const full_info: FullSnapshotFileInfo = .{ - .slot = latest_full.slot, - .hash = latest_full.hash, - }; - const latest_incremental = latest_snapshot_gen_info.inc orelse - break :blk .{ full_info, null }; - const inc_info: IncrementalSnapshotFileInfo = .{ - .base_slot = latest_full.slot, - .slot = latest_incremental.slot, - .hash = latest_incremental.hash, - }; - break :blk .{ full_info, inc_info }; - }; - - logger.debug().logf("Available full: {?s}", .{ - if (full_info) |info| info.snapshotArchiveName().constSlice() else null, - }); - logger.debug().logf("Available inc: {?s}", .{ - if (inc_info) |info| info.snapshotArchiveName().constSlice() else null, - }); - - if (full_info) |full| { - const full_archive_name_bounded = full.snapshotArchiveName(); - const full_archive_name = full_archive_name_bounded.constSlice(); - if (std.mem.eql(u8, path, full_archive_name)) { - const archive_file = try snapshot_dir.openFile(full_archive_name, .{}); - defer archive_file.close(); - var send_buffer: [4096]u8 = undefined; - try httpResponseSendFile(request, archive_file, &send_buffer); - return; - } - } - - if (inc_info) |inc| { - const inc_archive_name_bounded = inc.snapshotArchiveName(); - const inc_archive_name = inc_archive_name_bounded.constSlice(); - if (std.mem.eql(u8, path, inc_archive_name)) { - const archive_file = try snapshot_dir.openFile(inc_archive_name, .{}); - defer archive_file.close(); - var send_buffer: [4096]u8 = undefined; - try httpResponseSendFile(request, archive_file, &send_buffer); - return; - } - } - }, - else => {}, - } - - logger.err().logf( - "{} made an unrecognized request '{} {s}'", - .{ conn_address, methodFmt(request.head.method), request.head.target }, - ); - try request.respond("", .{ - .status = .not_found, - .keep_alive = false, - }); -} - -fn httpResponseSendFile( - request: *std.http.Server.Request, - archive_file: std.fs.File, - send_buffer: []u8, -) !void { - const archive_len = try archive_file.getEndPos(); - - var response = request.respondStreaming(.{ - .send_buffer = send_buffer, - .content_length = archive_len, - }); - const writer = sig.utils.io.narrowAnyWriter( - response.writer(), - std.http.Server.Response.WriteError, - ); - - const Fifo = std.fifo.LinearFifo(u8, .{ .Static = 1 }); - var fifo: Fifo = Fifo.init(); - try archive_file.seekTo(0); - try fifo.pump(archive_file.reader(), writer); - - try response.end(); -} - -const AcceptConnectionError = error{ - ProcessFdQuotaExceeded, - SystemFdQuotaExceeded, - SystemResources, - ProtocolFailure, - BlockedByFirewall, - NetworkSubsystemFailed, -} || std.posix.UnexpectedError; - -fn acceptConnection( - tcp_server: *std.net.Server, - logger: ScopedLogger, -) AcceptConnectionError!?std.net.Server.Connection { - const conn = tcp_server.accept() catch |err| switch (err) { - error.Unexpected, - => |e| return e, - - error.ProcessFdQuotaExceeded, - error.SystemFdQuotaExceeded, - error.SystemResources, - error.ProtocolFailure, - error.BlockedByFirewall, - error.NetworkSubsystemFailed, - => |e| return e, - - error.FileDescriptorNotASocket, - error.SocketNotListening, - error.OperationNotSupported, - => @panic("Improperly initialized server."), - - error.WouldBlock, - => return null, - - error.ConnectionResetByPeer, - error.ConnectionAborted, - => |e| { - logger.warn().logf("{}", .{e}); - return null; - }, - }; - - return conn; -} - -fn methodFmt(method: std.http.Method) MethodFmt { - return .{ .method = method }; -} - -const MethodFmt = struct { - method: std.http.Method, - pub fn format( - fmt: MethodFmt, - comptime fmt_str: []const u8, - fmt_options: std.fmt.FormatOptions, - writer: anytype, - ) @TypeOf(writer).Error!void { - _ = fmt_options; - if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt); - try fmt.method.write(writer); - } -}; - -test Server { - const allocator = std.testing.allocator; - - var prng = std.Random.DefaultPrng.init(0); - const random = prng.random(); - - // const logger: sig.trace.Logger = .{ .direct_print = .{ .max_level = .trace } }; - const logger: sig.trace.Logger = .noop; - - var test_data_dir = try std.fs.cwd().openDir("data/test-data", .{ .iterate = true }); - defer test_data_dir.close(); - - var tmp_dir_root = std.testing.tmpDir(.{}); - defer tmp_dir_root.cleanup(); - const tmp_dir = tmp_dir_root.dir; - - var snap_dir = try tmp_dir.makeOpenPath("snapshot", .{ .iterate = true }); - defer snap_dir.close(); - - const SnapshotFiles = sig.accounts_db.snapshots.SnapshotFiles; - const snap_files = try SnapshotFiles.find(allocator, test_data_dir); - - const full_snap_name_bounded = snap_files.full.snapshotArchiveName(); - const maybe_inc_snap_name_bounded = - if (snap_files.incremental()) |inc| inc.snapshotArchiveName() else null; - - { - const full_snap_name = full_snap_name_bounded.constSlice(); - - try test_data_dir.copyFile(full_snap_name, snap_dir, full_snap_name, .{}); - const full_snap_file = try snap_dir.openFile(full_snap_name, .{}); - defer full_snap_file.close(); - - const unpack = sig.accounts_db.snapshots.parallelUnpackZstdTarBall; - try unpack(allocator, logger, full_snap_file, snap_dir, 1, true); - } - - if (maybe_inc_snap_name_bounded) |inc_snap_name_bounded| { - const inc_snap_name = inc_snap_name_bounded.constSlice(); - - try test_data_dir.copyFile(inc_snap_name, snap_dir, inc_snap_name, .{}); - const inc_snap_file = try snap_dir.openFile(inc_snap_name, .{}); - defer inc_snap_file.close(); - - const unpack = sig.accounts_db.snapshots.parallelUnpackZstdTarBall; - try unpack(allocator, logger, inc_snap_file, snap_dir, 1, false); - } - - var accountsdb = try sig.accounts_db.AccountsDB.init(.{ - .allocator = allocator, - .logger = logger, - .snapshot_dir = snap_dir, - .geyser_writer = null, - .gossip_view = null, - .index_allocation = .ram, - .number_of_index_shards = 4, - .lru_size = null, - }); - defer accountsdb.deinit(); - - { - const FullAndIncrementalManifest = sig.accounts_db.snapshots.FullAndIncrementalManifest; - const all_snap_fields = try FullAndIncrementalManifest.fromFiles( - allocator, - logger, - snap_dir, - snap_files, - ); - defer all_snap_fields.deinit(allocator); - - (try accountsdb.loadWithDefaults( - allocator, - all_snap_fields, - 1, - true, - 300, - false, - false, - )).deinit(allocator); - } - - var thread_pool = sig.sync.ThreadPool.init(.{ .max_threads = 1 }); - defer { - thread_pool.shutdown(); - thread_pool.deinit(); - } - - const rpc_port = random.intRangeLessThan(u16, 8_000, 10_000); - var rpc_server = try Server.init(.{ - .allocator = allocator, - .logger = logger, - .snapshot_dir = snap_dir, - .latest_snapshot_gen_info = &accountsdb.latest_snapshot_gen_info, - .thread_pool = &thread_pool, - .socket_addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, rpc_port), - .read_buffer_size = 4096, - }); - defer rpc_server.joinDeinit(); - - try testExpectSnapshotResponse( - allocator, - &rpc_server, - &full_snap_name_bounded, - snap_dir, - ); - - if (maybe_inc_snap_name_bounded) |inc_snap_name_bounded| { - try testExpectSnapshotResponse( - allocator, - &rpc_server, - &inc_snap_name_bounded, - snap_dir, - ); - } -} - -fn testExpectSnapshotResponse( - allocator: std.mem.Allocator, - rpc_server: *Server, - snap_name_bounded: anytype, - snap_dir: std.fs.Dir, -) !void { - const rpc_port = rpc_server.tcp.listen_address.getPort(); - const snap_url_str_bounded = sig.utils.fmt.boundedFmt( - "http://localhost:{d}/{s}", - .{ rpc_port, sig.utils.fmt.boundedString(snap_name_bounded) }, - ); - const snap_url = try std.Uri.parse(snap_url_str_bounded.constSlice()); - - const serve_thread = try std.Thread.spawn(.{}, Server.acceptAndServeConnection, .{rpc_server}); - const actual_data = try testDownloadSelfSnapshot(allocator, snap_url); - defer allocator.free(actual_data); - serve_thread.join(); - - const snap_name = snap_name_bounded.constSlice(); - - const expected_data = try snap_dir.readFileAlloc(allocator, snap_name, 1 << 32); - defer allocator.free(expected_data); - - try std.testing.expectEqualSlices(u8, expected_data, actual_data); -} - -fn testDownloadSelfSnapshot( - allocator: std.mem.Allocator, - snap_url: std.Uri, -) ![]const u8 { - var client: std.http.Client = .{ .allocator = allocator }; - defer client.deinit(); - - var server_header_buffer: [4096 * 16]u8 = undefined; - var request = try client.open(.GET, snap_url, .{ - .server_header_buffer = &server_header_buffer, - }); - defer request.deinit(); - - try request.send(); - try request.finish(); - try request.wait(); - - const content_len = request.response.content_length.?; - const reader = request.reader(); - - const response_content = try reader.readAllAlloc(allocator, 1 << 32); - errdefer allocator.free(response_content); - - try std.testing.expectEqual(content_len, response_content.len); - - return response_content; -} diff --git a/src/rpc/server/basic.zig b/src/rpc/server/basic.zig new file mode 100644 index 000000000..da7422756 --- /dev/null +++ b/src/rpc/server/basic.zig @@ -0,0 +1,194 @@ +const builtin = @import("builtin"); +const std = @import("std"); +const sig = @import("../../sig.zig"); + +const server = @import("server.zig"); +const requests = server.requests; +const connection = server.connection; + +const LOGGER_SCOPE = "rpc.server.basic"; + +pub const AcceptAndServeConnectionError = + AcceptHandledError || + SetSocketSyncError || + std.http.Server.ReceiveHeadError || + std.http.Server.Response.WriteError || + std.mem.Allocator.Error || + std.fs.File.GetSeekPosError || + std.fs.File.OpenError || + std.fs.File.ReadError; + +pub fn acceptAndServeConnection(server_ctx: *server.Context) !void { + const logger = server_ctx.logger.withScope(LOGGER_SCOPE); + + const conn = acceptHandled( + server_ctx.tcp, + .blocking, + ) catch |err| switch (err) { + error.WouldBlock => return, + else => |e| return e, + }; + defer conn.stream.close(); + + server_ctx.wait_group.start(); + defer server_ctx.wait_group.finish(); + + const buffer = try server_ctx.allocator.alloc(u8, server_ctx.read_buffer_size); + defer server_ctx.allocator.free(buffer); + + var http_server = std.http.Server.init(conn, buffer); + var request = try http_server.receiveHead(); + + const conn_address = request.server.connection.address; + logger.info().logf("Responding to request from {}: {} {s}", .{ + conn_address, requests.methodFmt(request.head.method), request.head.target, + }); + + switch (request.head.method) { + .HEAD, .GET => switch (requests.getRequestTargetResolve( + logger.unscoped(), + request.head.target, + server_ctx.latest_snapshot_gen_info, + )) { + inline .full_snapshot, .inc_snapshot => |pair| { + const snap_info, var full_info_lg = pair; + defer full_info_lg.unlock(); + + const archive_name_bounded = snap_info.snapshotArchiveName(); + const archive_name = archive_name_bounded.constSlice(); + + const archive_file = try server_ctx.snapshot_dir.openFile(archive_name, .{}); + defer archive_file.close(); + + const archive_len = try archive_file.getEndPos(); + + var send_buffer: [4096]u8 = undefined; + var response = request.respondStreaming(.{ + .send_buffer = &send_buffer, + .content_length = archive_len, + .respond_options = .{}, + }); + + if (!response.elide_body) { + // use a length which is still a multiple of 2, greater than the send_buffer length, + // in order to almost always force the http server method to flush, instead of + // pointlessly copying data into the send buffer. + const read_buffer_len = comptime std.mem.alignForward( + usize, + send_buffer.len + 1, + 2, + ); + var read_buffer: [read_buffer_len]u8 = undefined; + + while (true) { + const file_data_len = try archive_file.read(&read_buffer); + if (file_data_len == 0) break; + const file_data = read_buffer[0..file_data_len]; + try response.writeAll(file_data); + } + } else { + std.debug.assert(response.transfer_encoding.content_length == archive_len); + // NOTE: in order to avoid needing to actually spend time writing the response body, + // just trick the API into thinking we already wrote the entire thing by setting this + // to 0. + response.transfer_encoding.content_length = 0; + } + + try response.end(); + return; + }, + .unrecognized => {}, + }, + .POST => { + logger.err().logf("{} tried to invoke our RPC", .{conn_address}); + return try request.respond("RPCs are not yet implemented", .{ + .status = .service_unavailable, + .keep_alive = false, + }); + }, + else => {}, + } + + logger.err().logf( + "{} made an unrecognized request '{} {s}'", + .{ conn_address, requests.methodFmt(request.head.method), request.head.target }, + ); + try request.respond("", .{ + .status = .not_found, + .keep_alive = false, + }); +} + +const SyncKind = enum { blocking, nonblocking }; + +const AcceptHandledError = + error{ + ConnectionAborted, + ProtocolFailure, + WouldBlock, +} || connection.HandleAcceptError || + SetSocketSyncError; + +fn acceptHandled( + tcp_server: std.net.Server, + sync: SyncKind, +) AcceptHandledError!std.net.Server.Connection { + var accept_flags: u32 = std.posix.SOCK.CLOEXEC; + accept_flags |= switch (sync) { + .blocking => 0, + .nonblocking => std.posix.SOCK.NONBLOCK, + }; + + // When this is false, it means we can't apply flags to + // the accepted socket, and we'll have to ensure that the + // relevant flags are enabled/disabled after acceptance. + const have_accept4 = comptime !builtin.target.isDarwin(); + + const conn: std.net.Server.Connection = while (true) { + var addr: std.net.Address = .{ .any = undefined }; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(addr.any)); + const rc = if (have_accept4) + std.posix.system.accept4(tcp_server.stream.handle, &addr.any, &addr_len, accept_flags) + else + std.posix.system.accept(tcp_server.stream.handle, &addr.any, &addr_len); + + break switch (try connection.handleAcceptResult(std.posix.errno(rc))) { + .intr => continue, + .conn_aborted => return error.ConnectionAborted, + .proto_fail => return error.ProtocolFailure, + .again => return error.WouldBlock, + .success => .{ + .stream = .{ .handle = rc }, + .address = addr, + }, + }; + }; + + if (!have_accept4) { + try setSocketSync(conn.stream.handle, sync); + } + + return conn; +} + +const SetSocketSyncError = std.posix.FcntlError; + +/// Ensure the socket is set to be blocking or nonblocking. +/// Useful in tandem with the situation described by `HAVE_ACCEPT4`. +fn setSocketSync( + socket: std.posix.socket_t, + sync: SyncKind, +) SetSocketSyncError!void { + const FlagsInt = @typeInfo(std.posix.O).Struct.backing_integer.?; + var flags_int: FlagsInt = @intCast(try std.posix.fcntl(socket, std.posix.F.GETFL, 0)); + const flags = std.mem.bytesAsValue(std.posix.O, std.mem.asBytes(&flags_int)); + + const nonblock_wanted = switch (sync) { + .blocking => false, + .nonblocking => true, + }; + if (flags.NONBLOCK != nonblock_wanted) { + flags.NONBLOCK = nonblock_wanted; + _ = try std.posix.fcntl(socket, std.posix.F.SETFL, flags_int); + } +} diff --git a/src/rpc/server/connection.zig b/src/rpc/server/connection.zig new file mode 100644 index 000000000..450475749 --- /dev/null +++ b/src/rpc/server/connection.zig @@ -0,0 +1,180 @@ +const builtin = @import("builtin"); +const std = @import("std"); + +pub const HandleAcceptError = error{ + ProcessFdQuotaExceeded, + SystemFdQuotaExceeded, + SystemResources, + BlockedByFirewall, +} || std.posix.UnexpectedError; + +pub const HandleAcceptResult = enum { + success, + intr, + again, + conn_aborted, + proto_fail, +}; + +/// Resembles the error handling of `std.posix.accept`. +pub fn handleAcceptResult( + /// Must be the result of `std.posix.accept` or equivalent (ie io_uring cqe.err()). + rc: std.posix.E, +) HandleAcceptError!HandleAcceptResult { + comptime std.debug.assert( // + builtin.target.isDarwin() or builtin.target.os.tag == .linux // + ); + return switch (rc) { + .SUCCESS => .success, + .INTR => .intr, + .AGAIN => .again, + .CONNABORTED => .conn_aborted, + .PROTO => .proto_fail, + + .BADF, // always a race condition + .FAULT, // don't address bad memory + .NOTSOCK, // don't call accept on a non-socket + .OPNOTSUPP, // socket must support accept + .INVAL, // socket must be listening + => |e| std.debug.panic("{s}", .{@tagName(e)}), + + .MFILE => return error.ProcessFdQuotaExceeded, + .NFILE => return error.SystemFdQuotaExceeded, + .NOBUFS => return error.SystemResources, + .NOMEM => return error.SystemResources, + .PERM => return error.BlockedByFirewall, + else => |err| return std.posix.unexpectedErrno(err), + }; +} + +pub const HandleRecvError = error{ + SystemResources, +} || std.posix.UnexpectedError; + +pub const HandleRecvResult = enum { + success, + intr, + again, + conn_refused, + conn_reset, + timed_out, +}; + +/// Resembles the error handling of `std.posix.recv`. +pub fn handleRecvResult( + /// Must be the result of `std.posix.recv` or equivalent (ie io_uring cqe.err()). + rc: std.posix.E, +) HandleRecvError!HandleRecvResult { + comptime std.debug.assert( // + builtin.target.isDarwin() or builtin.target.os.tag == .linux // + ); + return switch (rc) { + .SUCCESS => .success, + .INTR => .intr, + .AGAIN => .again, + .CONNREFUSED => .conn_refused, + .CONNRESET => .conn_reset, + .TIMEDOUT => .timed_out, + + .BADF, // always a race condition + .FAULT, // don't address bad memory + .INVAL, // socket must be listening + .NOTSOCK, // don't call accept on a non-socket + .NOTCONN, // we should always be connected + => |e| std.debug.panic("{s}", .{@tagName(e)}), + + .NOMEM => return error.SystemResources, + else => |err| return std.posix.unexpectedErrno(err), + }; +} + +pub const HandleSendError = error{ + AccessDenied, + FastOpenAlreadyInProgress, + MessageTooBig, + SystemResources, + NetworkSubsystemFailed, +} || std.posix.UnexpectedError; + +pub const HandleSendResult = enum { + success, + intr, + again, + conn_reset, + broken_pipe, +}; + +pub fn handleSendResult( + /// Must be the result of `std.posix.send` or equivalent (ie io_uring cqe.err()). + rc: std.posix.E, +) HandleSendError!HandleSendResult { + comptime std.debug.assert( // + builtin.target.isDarwin() or builtin.target.os.tag == .linux // + ); + return switch (rc) { + .SUCCESS => .success, + .INTR => .intr, + .AGAIN => .again, + .CONNRESET => .conn_reset, + .PIPE => .broken_pipe, + + .BADF, // always a race condition + .DESTADDRREQ, // The socket is not connection-mode, and no peer address is set. + .FAULT, // An invalid user space address was specified for an argument. + .ISCONN, // connection-mode socket was connected already but a recipient was specified + .NOTSOCK, // The file descriptor sockfd does not refer to a socket. + .OPNOTSUPP, // Some bit in the flags argument is inappropriate for the socket type. + + // these are all reachable through `sendto`, but unreachable through `send`. + .AFNOSUPPORT, + .LOOP, + .NAMETOOLONG, + .NOENT, + .NOTDIR, + .HOSTUNREACH, + .NETUNREACH, + .NOTCONN, + .INVAL, + => |e| std.debug.panic("{s}", .{@tagName(e)}), + + .ACCES => return error.AccessDenied, + .ALREADY => return error.FastOpenAlreadyInProgress, + .MSGSIZE => return error.MessageTooBig, + .NOBUFS, .NOMEM => return error.SystemResources, + .NETDOWN => return error.NetworkSubsystemFailed, + else => |e| std.posix.unexpectedErrno(e), + }; +} + +pub const HandleSpliceError = error{ + SystemResources, +} || std.posix.UnexpectedError; + +pub const HandleSpliceResult = enum { + success, + again, + /// One or both file descriptors are not valid, or do not have proper read-write mode. + bad_file_descriptors, + /// Either off_in or off_out was not NULL, but the corresponding file descriptor refers to a pipe. + bad_fd_offset, + /// Could be one of many reasons, see the manpage for splice. + invalid_splice, +}; + +pub fn handleSpliceResult( + /// Must be the result of calling the `splice` syscall or equivalent (ie io_uring cqe.err()). + rc: std.posix.E, +) HandleSpliceError!HandleSpliceResult { + comptime std.debug.assert( // + builtin.target.os.tag == .linux // + ); + return switch (rc) { + .SUCCESS => .success, + .AGAIN => .again, + .INVAL => .invalid_splice, + .SPIPE => .bad_fd_offset, + .BADF => .bad_file_descriptors, + .NOMEM => return error.SystemResources, + else => |err| std.posix.unexpectedErrno(err), + }; +} diff --git a/src/rpc/server/lib.zig b/src/rpc/server/lib.zig new file mode 100644 index 000000000..7fb70fcf1 --- /dev/null +++ b/src/rpc/server/lib.zig @@ -0,0 +1,25 @@ +//! RPC Server API. +//! +//! The server can be run by calling `serveSpawn`, or `serve`; in +//! order to do this, the caller must first initialize a `Context` +//! to provide the basic state and dependencies required to operate +//! the server, and must also provide a `WorkPool`, initialized to +//! a given backend. + +const server = @import("server.zig"); + +comptime { + _ = server; +} + +pub const MIN_READ_BUFFER_SIZE = server.MIN_READ_BUFFER_SIZE; + +pub const serveSpawn = server.serveSpawn; +pub const serve = server.serve; + +pub const Context = server.Context; +pub const WorkPool = server.WorkPool; + +// backends +pub const basic = server.basic; +pub const LinuxIoUring = server.LinuxIoUring; diff --git a/src/rpc/server/linux_io_uring.zig b/src/rpc/server/linux_io_uring.zig new file mode 100644 index 000000000..f91621b5a --- /dev/null +++ b/src/rpc/server/linux_io_uring.zig @@ -0,0 +1,852 @@ +const builtin = @import("builtin"); +const std = @import("std"); +const sig = @import("../../sig.zig"); +const server = @import("server.zig"); + +const requests = server.requests; +const connection = server.connection; + +const IoUring = std.os.linux.IoUring; + +const LOGGER_SCOPE = "rpc.server.linux_io_uring"; + +pub const LinuxIoUring = struct { + io_uring: IoUring, + + pub const can_use: bool = builtin.os.tag == .linux; + + pub const InitError = IouInitError || GetSqeRetryError; + + // NOTE(ink): constructing the return type as `E!?T`, where `E` and `T` are resolved + // separately seems to help ZLS with understanding the types involved better, which is + // why I've done it like that here. If ZLS gets smarter in the future, you could probably + // inline this into a single branch in the return type expression. + const InitErrOrEmpty = if (!can_use) error{} else InitError; + const InitResultOrNoreturn = if (!can_use) noreturn else LinuxIoUring; + pub fn init( + /// Not stored, only used for some initial SQE preps, the pointer needn't remain stable. + server_ctx: *const server.Context, + ) InitErrOrEmpty!?InitResultOrNoreturn { + if (!can_use) return null; + var io_uring = IoUring.init(4096, 0) catch |err| return switch (err) { + error.SystemOutdated, + error.PermissionDenied, + => return null, + else => |e| e, + }; + errdefer io_uring.deinit(); + + try prepMultishotAccept(&io_uring, server_ctx.tcp); + return .{ .io_uring = io_uring }; + } + + pub fn deinit(self: *LinuxIoUring) void { + self.io_uring.deinit(); + } + + pub const AcceptAndServeConnectionsError = + GetSqeRetryError || + ConsumeOurCqeError || + std.mem.Allocator.Error; + + pub fn acceptAndServeConnections( + self: *LinuxIoUring, + server_ctx: *server.Context, + ) AcceptAndServeConnectionsError!void { + const timeout_ts: std.os.linux.kernel_timespec = comptime .{ + .tv_sec = 1, + .tv_nsec = 0, + }; + + const timeout_sqe = try getSqeRetry(&self.io_uring); + timeout_sqe.prep_timeout(&timeout_ts, 1, 0); + timeout_sqe.user_data = 1; + + _ = try self.io_uring.submit_and_wait(1); + + var pending_cqes_buf: [255]std.os.linux.io_uring_cqe = undefined; + const pending_cqes_count = try self.io_uring.copy_cqes(&pending_cqes_buf, 0); + const cqes_pending = pending_cqes_buf[0..pending_cqes_count]; + + for (cqes_pending) |raw_cqe| { + // NOTE(ink): this is kind of hacky, should try refactoring this to use DOD-like indexes instead of pointers, + // that way we can allocate special static indexes instead of this. + if (raw_cqe.user_data == timeout_sqe.user_data) continue; + const our_cqe = OurCqe.fromCqe(raw_cqe); + try consumeOurCqe(self, server_ctx, our_cqe); + } + } +}; + +fn prepMultishotAccept( + io_uring: *IoUring, + tcp: std.net.Server, +) GetSqeRetryError!void { + const sqe = try getSqeRetry(io_uring); + sqe.prep_multishot_accept(tcp.stream.handle, null, null, std.os.linux.SOCK.CLOEXEC); + sqe.user_data = @bitCast(Entry.ACCEPT); +} + +const ConsumeOurCqeError = + HandleRecvBodyError || + std.mem.Allocator.Error || + connection.HandleAcceptError || + connection.HandleRecvError || + connection.HandleSendError || + connection.HandleSpliceError; + +/// On return, `cqe.user_data` is in an undefined state - this is to say, +/// it has either already been `deinit`ed, or it has been been re-submitted +/// in a new `SQE` and should not be modified; in either scenario, the caller +/// should not interact with it. +fn consumeOurCqe( + liou: *LinuxIoUring, + server_ctx: *server.Context, + cqe: OurCqe, +) ConsumeOurCqeError!void { + const logger = server_ctx.logger.withScope(LOGGER_SCOPE); + + const entry = cqe.user_data; + errdefer entry.deinit(server_ctx.allocator); + + const entry_data: *EntryData = entry.ptr orelse { + // `accept_multishot` cqe + + // we may need to re-submit the `accept_multishot` sqe. + const accept_cancelled = cqe.flags & std.os.linux.IORING_CQE_F_MORE == 0; + if (accept_cancelled) try prepMultishotAccept(&liou.io_uring, server_ctx.tcp); + + switch (try connection.handleAcceptResult(cqe.err())) { + .success => {}, + // just quickly exit; if we need to re-issue, that's already handled above + .intr, + .again, + .conn_aborted, + .proto_fail, + => return, + } + + const stream: std.net.Stream = .{ .handle = cqe.res }; + errdefer stream.close(); + + server_ctx.wait_group.start(); + errdefer server_ctx.wait_group.finish(); + + const buffer = try server_ctx.allocator.alloc(u8, server_ctx.read_buffer_size); + errdefer server_ctx.allocator.free(buffer); + + const data_ptr = try server_ctx.allocator.create(EntryData); + errdefer server_ctx.allocator.destroy(data_ptr); + data_ptr.* = .{ + .buffer = buffer, + .stream = stream, + .state = EntryState.INIT, + }; + + const sqe = try getSqeRetry(&liou.io_uring); + sqe.prep_recv(stream.handle, buffer, 0); + sqe.user_data = @bitCast(Entry{ .ptr = data_ptr }); + return; + }; + errdefer server_ctx.wait_group.finish(); + + const addr_err_logger = logger.err().field( + "address", + // if we fail to getSockName, just print the error in place of the address; + getSocketName(entry_data.stream.handle), + ); + errdefer addr_err_logger.log("Dropping connection"); + + // Panic message for handling `EAGAIN`; we're not using nonblocking sockets at all, + // so it should be impossible to receive that error, or for such an error to be + // triggered just from malicious connections. + const eagain_panic_msg = + "The file/socket should not be in nonblocking mode;" ++ + " server or file/socket configuration error."; + + switch (entry_data.state) { + .recv_head => |*head| { + switch (try connection.handleRecvResult(cqe.err())) { + .success => {}, + + .again => std.debug.panic(eagain_panic_msg, .{}), + + .intr => { + try head.prepRecv(entry, &liou.io_uring); + return; + }, + + .conn_refused, + .conn_reset, + .timed_out, + => { + entry.deinit(server_ctx.allocator); + return; + }, + } + + const recv_len: usize = @intCast(cqe.res); + std.debug.assert(head.parser.state != .finished); + + const recv_start = head.end; + const recv_end = recv_start + recv_len; + head.end += head.parser.feed(entry_data.buffer[recv_start..recv_end]); + + if (head.parser.state != .finished) { + std.debug.assert(head.end == recv_end); + if (head.end == entry_data.buffer.len) { + entry.deinit(server_ctx.allocator); + return; + } + + try head.prepRecv(entry, &liou.io_uring); + return; + } + + // copy relevant headers and information out of the buffer, + // so we can use the buffer exclusively for the request body. + const HeadInfo = requests.HeadInfo; + const head_info: HeadInfo = head_info: { + const head_bytes = entry_data.buffer[0..head.end]; + const std_head = std.http.Server.Request.Head.parse(head_bytes) catch |err| { + logger.err().logf("Head parse error: {s}", .{@errorName(err)}); + entry.deinit(server_ctx.allocator); + return; + }; + + // at the time of writing, this always holds true for the result of `Head.parse`. + std.debug.assert(std_head.compression == .none); + break :head_info HeadInfo.parseFromStdHead(std_head) catch |err| { + switch (err) { + error.RequestTargetTooLong => { + logger.err().logf("Request target was too long: '{}'", .{ + std.zig.fmtEscapes(std_head.target), + }); + }, + else => {}, + } + entry.deinit(server_ctx.allocator); + return; + }; + }; + + // ^ we just copied the relevant head info, so we're going to move + // the body content to the start of the buffer. + const content_end = blk: { + const old_content_bytes = entry_data.buffer[head.end..recv_end]; + std.mem.copyForwards( + u8, + entry_data.buffer[0..old_content_bytes.len], + old_content_bytes, + ); + break :blk old_content_bytes.len; + }; + + entry_data.state = .{ .recv_body = .{ + .head_info = head_info, + .content_end = content_end, + } }; + const body = &entry_data.state.recv_body; + handleRecvBody(liou, server_ctx, entry, body) catch |err| { + logger.err().logf("{s}", .{@errorName(err)}); + entry.deinit(server_ctx.allocator); + }; + return; + }, + + .recv_body => |*body| { + switch (try connection.handleRecvResult(cqe.err())) { + .success => {}, + .again => std.debug.panic(eagain_panic_msg, .{}), + .intr => @panic("TODO:"), + .conn_refused, + .conn_reset, + .timed_out, + => { + entry.deinit(server_ctx.allocator); + return; + }, + } + + const recv_len: usize = @intCast(cqe.res); + body.content_end += recv_len; + handleRecvBody(liou, server_ctx, entry, body) catch |err| { + logger.err().logf("{s}", .{@errorName(err)}); + entry.deinit(server_ctx.allocator); + }; + return; + }, + + .send_file_head => |*sfh| { + switch (try connection.handleSendResult(cqe.err())) { + .success => {}, + .again => std.debug.panic(eagain_panic_msg, .{}), + .intr => @panic("TODO:"), + .conn_reset, + .broken_pipe, + => { + entry.deinit(server_ctx.allocator); + return; + }, + } + const sent_len: usize = @intCast(cqe.res); + sfh.sent_bytes += sent_len; + + switch (try sfh.computeAndMaybePrepSend(entry, &liou.io_uring)) { + .sending_more => return, + .all_sent => switch (sfh.data) { + .file_size => { + entry.deinit(server_ctx.allocator); + server_ctx.wait_group.finish(); + return; + }, + .sfd => |sfd| { + entry_data.state = .{ .send_file_body = .{ + .sfd = sfd, + .spliced_to_pipe = 0, + .spliced_to_socket = 0, + .which = .to_pipe, + } }; + const sfb = &entry_data.state.send_file_body; + try sfb.prepSpliceFileToPipe(entry, &liou.io_uring); + return; + }, + }, + } + }, + + .send_file_body => |*sfb| switch (sfb.which) { + .to_pipe => { + switch (try connection.handleSpliceResult(cqe.err())) { + .success => {}, + .again => std.debug.panic(eagain_panic_msg, .{}), + .bad_file_descriptors, + .bad_fd_offset, + .invalid_splice, + => { + entry.deinit(server_ctx.allocator); + return; + }, + } + sfb.spliced_to_pipe += @intCast(cqe.res); + + sfb.which = .to_socket; + try sfb.prepSplicePipeToSocket(entry, &liou.io_uring); + + return; + }, + .to_socket => { + switch (try connection.handleSpliceResult(cqe.err())) { + .success => {}, + .again => std.debug.panic(eagain_panic_msg, .{}), + .bad_file_descriptors, + .bad_fd_offset, + .invalid_splice, + => { + entry.deinit(server_ctx.allocator); + return; + }, + } + sfb.spliced_to_socket += @intCast(cqe.res); + + if (sfb.spliced_to_socket < sfb.sfd.file_size) { + sfb.which = .to_pipe; + try sfb.prepSpliceFileToPipe(entry, &liou.io_uring); + } else { + std.debug.assert(sfb.spliced_to_socket == sfb.spliced_to_pipe); + entry.deinit(server_ctx.allocator); + server_ctx.wait_group.finish(); + } + return; + }, + }, + + .send_no_body => |*snb| { + switch (try connection.handleSendResult(cqe.err())) { + .success => {}, + .again => std.debug.panic(eagain_panic_msg, .{}), + .intr => @panic("TODO:"), + .conn_reset, + .broken_pipe, + => { + entry.deinit(server_ctx.allocator); + return; + }, + } + const sent_len: usize = @intCast(cqe.res); + snb.end_index += sent_len; + + if (snb.end_index < snb.head.len) { + try snb.prepSend(entry, &liou.io_uring); + return; + } else std.debug.assert(snb.end_index == snb.head.len); + + entry.deinit(server_ctx.allocator); + server_ctx.wait_group.finish(); + return; + }, + } + + comptime unreachable; +} + +const HandleRecvBodyError = + GetSqeRetryError || + std.fs.Dir.StatFileError || + std.fs.File.OpenError || + std.fs.File.GetSeekPosError || + std.posix.PipeError; + +fn handleRecvBody( + liou: *LinuxIoUring, + server_ctx: *server.Context, + entry: Entry, + body: *EntryState.RecvBody, +) HandleRecvBodyError!void { + const logger = server_ctx.logger.withScope(LOGGER_SCOPE); + + const entry_data = entry.ptr.?; + std.debug.assert(body == &entry_data.state.recv_body); + + if (!body.head_info.method.requestHasBody()) { + if (body.head_info.content_len) |content_len| { + logger.err().logf( + "{} request isn't expected to have a body, but got Content-Length: {d}", + .{ requests.methodFmt(body.head_info.method), content_len }, + ); + } + } + + switch (body.head_info.method) { + .POST => { + entry_data.state = .{ + .send_no_body = EntryState.SendNoBody.initHttStatus( + .@"HTTP/1.0", + .service_unavailable, + ), + }; + const snb = &entry_data.state.send_no_body; + try snb.prepSend(entry, &liou.io_uring); + return; + }, + + inline .HEAD, .GET => |method| switch (requests.getRequestTargetResolve( + logger.unscoped(), + body.head_info.target.constSlice(), + server_ctx.latest_snapshot_gen_info, + )) { + inline .full_snapshot, .inc_snapshot => |pair| { + const sfh_data: EntryState.SendFileHead.Data = switch (method) { + .HEAD => blk: { + const snap_info, var full_info_lg = pair; + defer full_info_lg.unlock(); + + const archive_name_bounded = snap_info.snapshotArchiveName(); + const archive_name = archive_name_bounded.constSlice(); + + const snapshot_dir = server_ctx.snapshot_dir; + const snap_stat = try snapshot_dir.statFile(archive_name); + break :blk .{ .file_size = snap_stat.size }; + }, + .GET => blk: { + const snap_info, var full_info_lg = pair; + errdefer full_info_lg.unlock(); + + const archive_name_bounded = snap_info.snapshotArchiveName(); + const archive_name = archive_name_bounded.constSlice(); + + const snapshot_dir = server_ctx.snapshot_dir; + const archive_file = try snapshot_dir.openFile(archive_name, .{}); + errdefer archive_file.close(); + + const file_size = try archive_file.getEndPos(); + + const pipe_r, const pipe_w = try std.posix.pipe(); + errdefer std.posix.close(pipe_w); + errdefer std.posix.close(pipe_r); + + break :blk .{ .sfd = .{ + .file_lg = full_info_lg, + .file = archive_file, + .file_size = file_size, + + .pipe_w = pipe_w, + .pipe_r = pipe_r, + } }; + }, + else => comptime unreachable, + }; + + entry_data.state = .{ .send_file_head = .{ + .sent_bytes = 0, + .data = sfh_data, + } }; + const sfh = &entry_data.state.send_file_head; + switch (try sfh.computeAndMaybePrepSend(entry, &liou.io_uring)) { + .all_sent => unreachable, // we know this for certain + .sending_more => {}, + } + return; + }, + .unrecognized => {}, + }, + + else => {}, + } + + entry_data.state = .{ + .send_no_body = EntryState.SendNoBody.initHttStatus( + .@"HTTP/1.0", + .not_found, + ), + }; + const snb = &entry_data.state.send_no_body; + try snb.prepSend(entry, &liou.io_uring); + return; +} + +const OurCqe = extern struct { + user_data: Entry, + res: i32, + flags: u32, + + fn fromCqe(cqe: std.os.linux.io_uring_cqe) OurCqe { + return .{ + .user_data = @bitCast(cqe.user_data), + .res = cqe.res, + .flags = cqe.flags, + }; + } + + fn asCqe(self: OurCqe) std.os.linux.io_uring_cqe { + return .{ + .user_data = @bitCast(self.user_data), + .res = self.res, + .flags = self.flags, + }; + } + + fn err(self: OurCqe) std.os.linux.E { + return self.asCqe().err(); + } +}; + +const Entry = packed struct(u64) { + /// If null, this is an `accept` entry. + ptr: ?*EntryData, + + const ACCEPT: Entry = .{ .ptr = null }; + + fn deinit(self: Entry, allocator: std.mem.Allocator) void { + const ptr = self.ptr orelse return; + ptr.deinit(allocator); + allocator.destroy(ptr); + } +}; + +const EntryData = struct { + buffer: []u8, + stream: std.net.Stream, + state: EntryState, + + fn deinit(self: *EntryData, allocator: std.mem.Allocator) void { + self.state.deinit(); + allocator.free(self.buffer); + self.stream.close(); + } +}; + +const EntryState = union(enum) { + recv_head: RecvHead, + recv_body: RecvBody, + send_file_head: SendFileHead, + send_file_body: SendFileBody, + send_no_body: SendNoBody, + + const INIT: EntryState = .{ + .recv_head = .{ + .end = 0, + .parser = .{}, + }, + }; + + fn deinit(self: *EntryState) void { + switch (self.*) { + .recv_head => {}, + .recv_body => {}, + .send_file_head => |*sfh| sfh.deinit(), + .send_file_body => |*sfb| sfb.deinit(), + .send_no_body => {}, + } + } + + const RecvHead = struct { + end: usize, + parser: std.http.HeadParser, + + fn prepRecv( + self: *const RecvHead, + entry: Entry, + io_uring: *IoUring, + ) GetSqeRetryError!void { + const entry_ptr = entry.ptr.?; + std.debug.assert(self == &entry_ptr.state.recv_head); + + const usable_buffer = entry_ptr.buffer[self.end..]; + const sqe = try getSqeRetry(io_uring); + sqe.prep_recv(entry_ptr.stream.handle, usable_buffer, 0); + sqe.user_data = @bitCast(entry); + } + }; + + const RecvBody = struct { + head_info: requests.HeadInfo, + /// The current number of content bytes read into the buffer. + content_end: usize, + }; + + const SendFileData = struct { + file_lg: requests.GetRequestTargetResolved.SnapshotReadLock, + file: std.fs.File, + file_size: u64, + + pipe_w: std.os.linux.fd_t, + pipe_r: std.os.linux.fd_t, + + fn deinit(self: *SendFileData) void { + self.file.close(); + self.file_lg.unlock(); + std.posix.close(self.pipe_w); + std.posix.close(self.pipe_r); + } + }; + + const SendFileHead = struct { + sent_bytes: u64, + data: Data, + + const Data = union(enum) { + /// Just responding to a HEAD request. + file_size: u64, + sfd: SendFileData, + }; + + fn deinit(self: *SendFileHead) void { + switch (self.data) { + .sfd => |*sfd| sfd.deinit(), + .file_size => {}, + } + } + + /// If `self.sent_bytes` is equal to the number of rendered head bytes, this + /// will return `.all_sent`, which means it won't have queued any SQEs; otherwise, + /// it is guaranteed to return `.sending_more` - the latter would always be the + /// case when `self.sent_bytes == 0` for example. + fn computeAndMaybePrepSend( + self: *SendFileHead, + entry: Entry, + io_uring: *IoUring, + ) GetSqeRetryError!enum { + /// The head has been fully sent already, no send was prepped. + all_sent, + /// There is still more head data to send. + sending_more, + } { + const entry_data = entry.ptr.?; + std.debug.assert(self == &entry_data.state.send_file_head); + + const rendered_len = blk: { + // render segments of the head into our buffer, + // sending them as they become rendered. + + var ww = sig.utils.io.WindowedWriter.init(entry_data.buffer, self.sent_bytes); + var cw = std.io.countingWriter(ww.writer()); + const writer = cw.writer(); + + const status: std.http.Status = .ok; + writer.print("{[version]s} {[status]d}{[space]s}{[phrase]s}\r\n", .{ + .version = @tagName(std.http.Version.@"HTTP/1.0"), + .status = @intFromEnum(status), + .space = if (status.phrase() != null) " " else "", + .phrase = if (status.phrase()) |str| str else "", + }) catch |err| switch (err) {}; + + const file_size = switch (self.data) { + .sfd => |sfd| sfd.file_size, + .file_size => |file_size| file_size, + }; + writer.print("Content-Length: {d}\r\n", .{file_size}) catch |err| switch (err) {}; + + writer.writeAll("\r\n") catch |err| switch (err) {}; + + if (self.sent_bytes == cw.bytes_written) return .all_sent; + std.debug.assert(self.sent_bytes < cw.bytes_written); + break :blk ww.end_index; + }; + + const sqe = try getSqeRetry(io_uring); + sqe.prep_send(entry_data.stream.handle, entry_data.buffer[0..rendered_len], 0); + sqe.user_data = @bitCast(entry); + + return .sending_more; + } + }; + + const SendFileBody = struct { + sfd: SendFileData, + spliced_to_pipe: u64, + spliced_to_socket: u64, + which: Which, + + const Which = enum { + to_pipe, + to_socket, + }; + + fn deinit(self: *SendFileBody) void { + self.sfd.deinit(); + } + + fn prepSpliceFileToPipe( + self: *const SendFileBody, + entry: Entry, + io_uring: *IoUring, + ) GetSqeRetryError!void { + const entry_ptr = entry.ptr.?; + std.debug.assert(self == &entry_ptr.state.send_file_body); + std.debug.assert(self.which == .to_pipe); + + const sqe = try getSqeRetry(io_uring); + sqe.prep_splice( + self.sfd.file.handle, + self.spliced_to_pipe, + self.sfd.pipe_w, + std.math.maxInt(u64), + self.sfd.file_size - self.spliced_to_pipe, + ); + sqe.user_data = @bitCast(entry); + } + + fn prepSplicePipeToSocket( + self: *const SendFileBody, + entry: Entry, + io_uring: *IoUring, + ) GetSqeRetryError!void { + const entry_ptr = entry.ptr.?; + std.debug.assert(self == &entry_ptr.state.send_file_body); + std.debug.assert(self.which == .to_socket); + + const stream = entry_ptr.stream; + + const sqe = try getSqeRetry(io_uring); + sqe.prep_splice( + self.sfd.pipe_r, + std.math.maxInt(u64), + stream.handle, + std.math.maxInt(u64), + self.sfd.file_size - self.spliced_to_socket, + ); + sqe.user_data = @bitCast(entry); + } + }; + + const SendNoBody = struct { + /// Should be a statically-lived string. + head: []const u8, + end_index: usize, + + fn initString(comptime str: []const u8) SendNoBody { + return .{ + .head = str, + .end_index = 0, + }; + } + + fn initHttStatus( + comptime version: std.http.Version, + comptime status: std.http.Status, + ) SendNoBody { + const head = comptime std.fmt.comptimePrint("{s} {d}{s}\r\n\r\n", .{ + @tagName(version), + @intFromEnum(status), + if (status.phrase()) |phrase| " " ++ phrase else "", + }); + return initString(head); + } + + fn prepSend( + self: *const SendNoBody, + entry: Entry, + io_uring: *IoUring, + ) GetSqeRetryError!void { + const entry_ptr = entry.ptr.?; + std.debug.assert(self == &entry_ptr.state.send_no_body); + + const sqe = try getSqeRetry(io_uring); + sqe.prep_send(entry_ptr.stream.handle, self.head[self.end_index..], 0); + sqe.user_data = @bitCast(entry); + } + }; +}; + +fn getSocketName( + socket_handle: std.posix.socket_t, +) std.posix.GetSockNameError!std.net.Address { + var addr: std.net.Address = .{ .any = undefined }; + var addr_len: std.posix.socklen_t = @sizeOf(@TypeOf(addr.any)); + try std.posix.getsockname(socket_handle, &addr.any, &addr_len); + return addr; +} + +const GetSqeRetryError = IouEnterError; + +/// Try to `get_sqe`; if the submission queue is too full for that, call `submit()`, +/// and then try again, and panic if there's still somehow no room. +fn getSqeRetry(io_uring: *std.os.linux.IoUring) GetSqeRetryError!*std.os.linux.io_uring_sqe { + if (io_uring.get_sqe()) |sqe| return sqe else |_| {} + _ = try io_uring.submit(); + return io_uring.get_sqe() catch + std.debug.panic("Failed to queue entry after flushing submission queue", .{}); +} + +const IouInitError = std.posix.MMapError || error{ + EntriesZero, + EntriesNotPowerOfTwo, + + ParamsOutsideAccessibleAddressSpace, + ArgumentsInvalid, + ProcessFdQuotaExceeded, + SystemFdQuotaExceeded, + SystemResources, + + PermissionDenied, + SystemOutdated, +}; + +/// Extracted from `std.os.linux.IoUring.enter`. +const IouEnterError = error{ + /// The kernel was unable to allocate memory or ran out of resources for the request. + /// The application should wait for some completions and try again. + SystemResources, + /// The SQE `fd` is invalid, or IOSQE_FIXED_FILE was set but no files were registered. + FileDescriptorInvalid, + /// The file descriptor is valid, but the ring is not in the right state. + /// See io_uring_register(2) for how to enable the ring. + FileDescriptorInBadState, + /// The application attempted to overcommit the number of requests it can have pending. + /// The application should wait for some completions and try again. + CompletionQueueOvercommitted, + /// The SQE is invalid, or valid but the ring was setup with IORING_SETUP_IOPOLL. + SubmissionQueueEntryInvalid, + /// The buffer is outside the process' accessible address space, or IORING_OP_READ_FIXED + /// or IORING_OP_WRITE_FIXED was specified but no buffers were registered, or the range + /// described by `addr` and `len` is not within the buffer registered at `buf_index`: + BufferInvalid, + RingShuttingDown, + /// The kernel believes our `self.fd` does not refer to an io_uring instance, + /// or the opcode is valid but not supported by this kernel (more likely): + OpcodeNotSupported, + /// The operation was interrupted by a delivery of a signal before it could complete. + /// This can happen while waiting for events with IORING_ENTER_GETEVENTS: + SignalInterrupt, +} || std.posix.UnexpectedError; diff --git a/src/rpc/server/requests.zig b/src/rpc/server/requests.zig new file mode 100644 index 000000000..993b431a5 --- /dev/null +++ b/src/rpc/server/requests.zig @@ -0,0 +1,182 @@ +//! This file defines most of the shared logic for the bounds and handling +//! of RPC requests. + +const builtin = @import("builtin"); +const std = @import("std"); +const sig = @import("../../sig.zig"); + +const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; +const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo; +const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo; + +/// A single request body cannot be larger than this; +/// a single chunk in a chunked request body cannot be larger than this, +/// but all together they may be allowed to be larger than this, +/// depending on the request. +pub const MAX_REQUEST_BODY_SIZE: usize = 50 * 1024; // 50 KiB + +const LOGGER_SCOPE = "rpc.server.requests"; + +/// All of the relevant information from a request head parsed into a narrow +/// format that is comprised of bounded data and can be copied by value. +pub const HeadInfo = struct { + method: std.http.Method, + target: TargetBoundedStr, + content_len: ?u64, + content_type: ?ContentType, + transfer_encoding: std.http.TransferEncoding, + content_encoding: std.http.ContentEncoding, + + const StdHead = std.http.Server.Request.Head; + + pub const ParseError = StdHead.ParseError || ParseFromStdHeadError; + + pub fn parse(head_bytes: []const u8) ParseError!HeadInfo { + const parsed_head = try StdHead.parse(head_bytes); + // at the time of writing, this always holds true for the result of `Head.parse`. + std.debug.assert(parsed_head.compression == .none); + return try parseFromStdHead(parsed_head); + } + + pub const ParseFromStdHeadError = error{ + RequestTargetTooLong, + RequestContentTypeUnrecognized, + }; + + pub fn parseFromStdHead(std_head: StdHead) ParseFromStdHeadError!HeadInfo { + // TODO: should we care about these? + _ = std_head.version; + _ = std_head.expect; + _ = std_head.keep_alive; + + const target = TargetBoundedStr.fromSlice(std_head.target) catch + return error.RequestTargetTooLong; + + const content_type: ?ContentType = ct: { + const str = std_head.content_type orelse break :ct null; + break :ct std.meta.stringToEnum(ContentType, str) orelse + return error.RequestContentTypeUnrecognized; + }; + + return .{ + .method = std_head.method, + .target = target, + .content_len = std_head.content_length, + .content_type = content_type, + .transfer_encoding = std_head.transfer_encoding, + .content_encoding = std_head.transfer_compression, + }; + } +}; + +pub const ContentType = enum(u8) { + @"application/json", +}; + +pub const MAX_TARGET_LEN: usize = blk: { + const SnapSpec = IncrementalSnapshotFileInfo.SnapshotArchiveNameFmtSpec; + break :blk "/".len + SnapSpec.fmtLenValue(.{ + .base_slot = std.math.maxInt(sig.core.Slot), + .slot = std.math.maxInt(sig.core.Slot), + .hash = sig.core.Hash.base58String(.{ .data = .{255} ** sig.core.Hash.SIZE }).constSlice(), + }); +}; +pub const TargetBoundedStr = std.BoundedArray(u8, MAX_TARGET_LEN); + +pub const GetRequestTargetResolved = union(enum) { + unrecognized, + full_snapshot: struct { FullSnapshotFileInfo, SnapshotReadLock }, + inc_snapshot: struct { IncrementalSnapshotFileInfo, SnapshotReadLock }, + + // TODO: also handle the snapshot archive aliases & other routes + + pub const SnapshotReadLock = sig.sync.RwMux(?SnapshotGenerationInfo).RLockGuard; +}; + +/// Resolve a `GET` request target. +pub fn getRequestTargetResolve( + unscoped_logger: sig.trace.Logger, + target: []const u8, + latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), +) GetRequestTargetResolved { + const logger = unscoped_logger.withScope(LOGGER_SCOPE); + + if (!std.mem.startsWith(u8, target, "/")) return .unrecognized; + const path = target[1..]; + + const is_snapshot_archive_like = + !std.meta.isError(FullSnapshotFileInfo.parseFileNameTarZst(path)) or + !std.meta.isError(IncrementalSnapshotFileInfo.parseFileNameTarZst(path)); + + if (is_snapshot_archive_like) { + // we hold the lock for the entirety of this process in order to prevent + // the snapshot generation process from deleting the associated snapshot. + const maybe_latest_snapshot_gen_info, // + var latest_snapshot_info_lg // + = latest_snapshot_gen_info_rw.readWithLock(); + errdefer latest_snapshot_info_lg.unlock(); + + const full_info: ?FullSnapshotFileInfo, // + const inc_info: ?IncrementalSnapshotFileInfo // + = blk: { + const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse + break :blk .{ null, null }; + const latest_full = latest_snapshot_gen_info.full; + const full_info: FullSnapshotFileInfo = .{ + .slot = latest_full.slot, + .hash = latest_full.hash, + }; + const latest_incremental = latest_snapshot_gen_info.inc orelse + break :blk .{ full_info, null }; + const inc_info: IncrementalSnapshotFileInfo = .{ + .base_slot = latest_full.slot, + .slot = latest_incremental.slot, + .hash = latest_incremental.hash, + }; + break :blk .{ full_info, inc_info }; + }; + + logger.debug().logf("Available full: {?s}", .{ + if (full_info) |info| info.snapshotArchiveName().constSlice() else null, + }); + logger.debug().logf("Available inc: {?s}", .{ + if (inc_info) |info| info.snapshotArchiveName().constSlice() else null, + }); + + if (full_info) |full| { + const full_archive_name_bounded = full.snapshotArchiveName(); + const full_archive_name = full_archive_name_bounded.constSlice(); + if (std.mem.eql(u8, path, full_archive_name)) { + return .{ .full_snapshot = .{ full, latest_snapshot_info_lg } }; + } + } + + if (inc_info) |inc| { + const inc_archive_name_bounded = inc.snapshotArchiveName(); + const inc_archive_name = inc_archive_name_bounded.constSlice(); + if (std.mem.eql(u8, path, inc_archive_name)) { + return .{ .inc_snapshot = .{ inc, latest_snapshot_info_lg } }; + } + } + } + + return .unrecognized; +} + +pub fn methodFmt(method: std.http.Method) MethodFmt { + return .{ .method = method }; +} + +pub const MethodFmt = struct { + method: std.http.Method, + pub fn format( + self: MethodFmt, + comptime fmt_str: []const u8, + fmt_options: std.fmt.FormatOptions, + writer: anytype, + ) @TypeOf(writer).Error!void { + _ = fmt_options; + if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, self); + try self.method.write(writer); + } +}; diff --git a/src/rpc/server/server.zig b/src/rpc/server/server.zig new file mode 100644 index 000000000..13e91a97b --- /dev/null +++ b/src/rpc/server/server.zig @@ -0,0 +1,295 @@ +//! RPC Server implementation. +//! +//! This file defines and exposes the relevant public API for +//! the RPC Server, as well as the internal API for backends +//! and any other internal code. + +const std = @import("std"); +const sig = @import("../../sig.zig"); + +pub const connection = @import("connection.zig"); +pub const requests = @import("requests.zig"); + +pub const basic = @import("basic.zig"); +pub const LinuxIoUring = @import("linux_io_uring.zig").LinuxIoUring; + +const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; + +/// The minimum buffer read size. +pub const MIN_READ_BUFFER_SIZE = 4096; + +const LOGGER_SCOPE = "rpc.server"; + +/// The work pool is a tagged union, representing one of various possible backends. +/// It acts merely as a reference to a specific backend's state, or a tag for stateless +/// backends. +pub const WorkPool = union(enum) { + basic, + linux_io_uring: if (LinuxIoUring.can_use) *LinuxIoUring else noreturn, +}; + +/// The basic state required for the server to operate. +pub const Context = struct { + allocator: std.mem.Allocator, + logger: sig.trace.log.ScopedLogger(LOGGER_SCOPE), + snapshot_dir: std.fs.Dir, + latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), + + /// Wait group for all currently running tasks, used to wait for + /// all of them to finish before deinitializing. + wait_group: std.Thread.WaitGroup, + tcp: std.net.Server, + /// Must not be mutated. + read_buffer_size: u32, + + /// The returned result must be pinned to a memory location before calling any methods. + pub fn init(params: struct { + /// Must be a thread-safe allocator. + allocator: std.mem.Allocator, + logger: sig.trace.Logger, + + /// Not closed by the `Server`, but must live at least as long as it. + snapshot_dir: std.fs.Dir, + /// Should reflect the latest generated snapshot eligible for propagation at any + /// given time with respect to the contents of the specified `snapshot_dir`. + latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), + + /// The size for the read buffer allocated to every request. + /// Clamped to be greater than or equal to `MIN_READ_BUFFER_SIZE`. + read_buffer_size: u32, + /// The socket address to listen on for incoming HTTP and/or RPC requests. + socket_addr: std.net.Address, + /// See `@FieldType(std.net.Address.ListenOptions, "reuse_address")`. + reuse_address: bool = false, + }) std.net.Address.ListenError!Context { + var tcp_server = try params.socket_addr.listen(.{ + .force_nonblocking = true, + .reuse_address = params.reuse_address, + }); + errdefer tcp_server.deinit(); + + return .{ + .allocator = params.allocator, + .logger = params.logger.withScope(LOGGER_SCOPE), + .snapshot_dir = params.snapshot_dir, + .latest_snapshot_gen_info = params.latest_snapshot_gen_info, + + .wait_group = .{}, + .read_buffer_size = @max(params.read_buffer_size, MIN_READ_BUFFER_SIZE), + .tcp = tcp_server, + }; + } + + /// Blocks until all tasks are completed, and then closes the server context. + /// Does not force the server to exit. + pub fn joinDeinit(self: *Context) void { + self.wait_group.wait(); + self.tcp.deinit(); + } +}; + +/// Spawn `serve` as a separate thread. +pub fn serveSpawn( + exit: *std.atomic.Value(bool), + ctx: *Context, + work_pool: WorkPool, +) std.Thread.SpawnError!std.Thread { + return try std.Thread.spawn(.{}, serve, .{ exit, ctx, work_pool }); +} + +pub const ServeError = + basic.AcceptAndServeConnectionError || + LinuxIoUring.AcceptAndServeConnectionsError; + +/// Until `exit.load(.acquire)`, accepts and serves connections in a loop. +pub fn serve( + /// The exit condition. + exit: *std.atomic.Value(bool), + /// The context to operate with. + ctx: *Context, + /// The pool to dispatch work to. + work_pool: WorkPool, +) ServeError!void { + while (!exit.load(.acquire)) { + switch (work_pool) { + .basic => try basic.acceptAndServeConnection(ctx), + .linux_io_uring => |linux| try linux.acceptAndServeConnections(ctx), + } + } +} + +test serveSpawn { + if (sig.build_options.no_network_tests) return error.SkipZigTest; + const allocator = std.testing.allocator; + + var prng = std.Random.DefaultPrng.init(0); + const random = prng.random(); + + var tmp_dir_root = std.testing.tmpDir(.{}); + defer tmp_dir_root.cleanup(); + const tmp_dir = tmp_dir_root.dir; + + // const logger_unscoped: sig.trace.Logger = .{ .direct_print = .{ .max_level = .trace } }; + const logger_unscoped: sig.trace.Logger = .noop; + + const logger = logger_unscoped.withScope(@src().fn_name); + + // the directory into which the snapshots will be unpacked. + var unpacked_snap_dir = try tmp_dir.makeOpenPath("snapshot", .{}); + defer unpacked_snap_dir.close(); + + // the source from which `fundAndUnpackTestSnapshots` will unpack the snapshots. + var test_data_dir = try std.fs.cwd().openDir(sig.TEST_DATA_DIR, .{ .iterate = true }); + defer test_data_dir.close(); + + const snap_files = try sig.accounts_db.db.findAndUnpackTestSnapshots( + std.Thread.getCpuCount() catch 1, + unpacked_snap_dir, + ); + + var latest_snapshot_gen_info = sig.sync.RwMux(?SnapshotGenerationInfo).init(blk: { + const FullAndIncrementalManifest = sig.accounts_db.snapshots.FullAndIncrementalManifest; + const all_snap_fields = try FullAndIncrementalManifest.fromFiles( + allocator, + logger.unscoped(), + unpacked_snap_dir, + snap_files, + ); + defer all_snap_fields.deinit(allocator); + + break :blk .{ + .full = .{ + .slot = snap_files.full.slot, + .hash = snap_files.full.hash, + .capitalization = all_snap_fields.full.bank_fields.capitalization, + }, + .inc = inc: { + const inc = all_snap_fields.incremental orelse break :inc null; + // if the incremental snapshot field is not null, these shouldn't be either + const inc_info = snap_files.incremental_info.?; + const inc_persist = inc.bank_extra.snapshot_persistence.?; + break :inc .{ + .slot = inc_info.slot, + .hash = inc_info.hash, + .capitalization = inc_persist.incremental_capitalization, + }; + }, + }; + }); + + const rpc_port = random.intRangeLessThan(u16, 8_000, 10_000); + const sock_addr = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, rpc_port); + var server_ctx = try Context.init(.{ + .allocator = allocator, + .logger = logger.unscoped(), + .snapshot_dir = test_data_dir, + .latest_snapshot_gen_info = &latest_snapshot_gen_info, + .socket_addr = sock_addr, + .read_buffer_size = 4096, + .reuse_address = true, + }); + defer server_ctx.joinDeinit(); + + var maybe_liou = try LinuxIoUring.init(&server_ctx); + // TODO: currently `if (a) |*b|` on `a: ?noreturn` causes analysis of + // the unwrap block, even though `if (a) |b|` doesn't; fixed in 0.14 + defer if (maybe_liou != null) maybe_liou.?.deinit(); + + for ([_]?WorkPool{ + .basic, + // TODO: see above TODO about `if (a) |*b|` on `?noreturn`. + if (maybe_liou != null) .{ .linux_io_uring = &maybe_liou.? } else null, + }) |maybe_work_pool| { + const work_pool = maybe_work_pool orelse continue; + logger.info().logf("Running with {s}", .{@tagName(work_pool)}); + + var exit = std.atomic.Value(bool).init(false); + const serve_thread = try serveSpawn(&exit, &server_ctx, work_pool); + defer serve_thread.join(); + defer exit.store(true, .release); + + try testExpectSnapshotResponse( + allocator, + test_data_dir, + server_ctx.tcp.listen_address.getPort(), + .full, + snap_files.full, + ); + + if (snap_files.incremental()) |inc| { + try testExpectSnapshotResponse( + allocator, + test_data_dir, + server_ctx.tcp.listen_address.getPort(), + .incremental, + inc, + ); + } + } +} + +fn testExpectSnapshotResponse( + allocator: std.mem.Allocator, + snap_dir: std.fs.Dir, + rpc_port: u16, + comptime kind: enum { full, incremental }, + snap_info: switch (kind) { + .full => sig.accounts_db.snapshots.FullSnapshotFileInfo, + .incremental => sig.accounts_db.snapshots.IncrementalSnapshotFileInfo, + }, +) !void { + const snap_name_bounded = snap_info.snapshotArchiveName(); + const snap_name = snap_name_bounded.constSlice(); + + const expected_file = try snap_dir.openFile(snap_name, .{}); + defer expected_file.close(); + + const expected_data: []align(std.mem.page_size) const u8 = try std.posix.mmap( + null, + try expected_file.getEndPos(), + std.posix.PROT.READ, + .{ .TYPE = .PRIVATE }, + expected_file.handle, + 0, + ); + defer std.posix.munmap(expected_data); + + const snap_url_str_bounded = sig.utils.fmt.boundedFmt( + "http://localhost:{d}/{s}", + .{ rpc_port, sig.utils.fmt.boundedString(&snap_name_bounded) }, + ); + const snap_url = try std.Uri.parse(snap_url_str_bounded.constSlice()); + + const actual_data = try testDownloadSelfSnapshot(allocator, snap_url); + defer allocator.free(actual_data); + + try std.testing.expectEqualSlices(u8, expected_data, actual_data); +} + +fn testDownloadSelfSnapshot( + allocator: std.mem.Allocator, + snap_url: std.Uri, +) ![]const u8 { + var client: std.http.Client = .{ .allocator = allocator }; + defer client.deinit(); + + var server_header_buffer: [4096 * 16]u8 = undefined; + var request = try client.open(.GET, snap_url, .{ + .server_header_buffer = &server_header_buffer, + }); + defer request.deinit(); + + try request.send(); + try request.finish(); + try request.wait(); + + const content_len = request.response.content_length.?; + const reader = request.reader(); + + const response_content = try reader.readAllAlloc(allocator, 1 << 32); + errdefer allocator.free(response_content); + + try std.testing.expectEqual(content_len, response_content.len); + + return response_content; +} diff --git a/src/sig.zig b/src/sig.zig index 874f88766..a12ab1ba2 100644 --- a/src/sig.zig +++ b/src/sig.zig @@ -20,6 +20,7 @@ pub const trace = @import("trace/lib.zig"); pub const transaction_sender = @import("transaction_sender/lib.zig"); pub const utils = @import("utils/lib.zig"); pub const version = @import("version/version.zig"); +pub const build_options = @import("build-options"); pub const VALIDATOR_DIR = "validator/"; /// subdirectory of {VALIDATOR_DIR} which contains the accounts database diff --git a/src/utils/fmt.zig b/src/utils/fmt.zig index bfd739a57..e8a3f519b 100644 --- a/src/utils/fmt.zig +++ b/src/utils/fmt.zig @@ -36,7 +36,7 @@ pub fn BoundedSpec(comptime spec: []const u8) type { /// try expectEqual("foo-255".len, boundedLenValue("{[a]s}-{[b]d}", .{ .a = "foo", .b = 255 })); /// ``` pub inline fn fmtLenValue(comptime args_value: anytype) usize { - comptime return fmtLen(fmt_str, @TypeOf(args_value)); + comptime return fmtLen(@TypeOf(args_value)); } pub fn BoundedArray(comptime Args: type) type { diff --git a/src/utils/io.zig b/src/utils/io.zig index becea8b07..f2d72dcab 100644 --- a/src/utils/io.zig +++ b/src/utils/io.zig @@ -87,3 +87,87 @@ fn NarrowAnyStream(comptime Error: type) type { } }; } + +/// Writer which captures only an offset window of data into a buffer. +/// This can be useful for incrementally capturing formatted data. +pub const WindowedWriter = struct { + remaining_to_ignore: u64, + end_index: usize, + buffer: []u8, + + pub fn init( + buffer: []u8, + start_bytes_to_ignore: u64, + ) WindowedWriter { + std.debug.assert(buffer.len != 0); + return .{ + .remaining_to_ignore = start_bytes_to_ignore, + .end_index = 0, + .buffer = buffer, + }; + } + + pub fn reset(self: *WindowedWriter, start_bytes_to_ignore: usize) void { + self.remaining_to_ignore = start_bytes_to_ignore; + self.end_index = 0; + } + + pub fn write(self: *WindowedWriter, bytes: []const u8) void { + const bytes_to_skip = @min(self.remaining_to_ignore, bytes.len); + self.remaining_to_ignore -|= bytes.len; + + const src_target_bytes = bytes[bytes_to_skip..]; + const writable = self.buffer[self.end_index..]; + + const amt = @min(writable.len, src_target_bytes.len); + @memcpy(writable[0..amt], src_target_bytes[0..amt]); + self.end_index += amt; + } + + pub const Writer = std.io.GenericWriter(*WindowedWriter, error{}, writerFn); + pub fn writer(self: *WindowedWriter) Writer { + return .{ .context = self }; + } + + fn writerFn(self: *WindowedWriter, bytes: []const u8) error{}!usize { + self.write(bytes); + return bytes.len; + } +}; + +fn testWindowedWriter( + comptime kind: enum { bin, str }, + params: struct { start: usize, size: usize }, + data: []const u8, + expected: []const u8, +) !void { + const buffer = try std.testing.allocator.alloc(u8, params.size); + defer std.testing.allocator.free(buffer); + + var ww = WindowedWriter.init(buffer, params.start); + for (0..data.len) |split_i| { + ww.reset(params.start); + ww.write(data[0..split_i]); + ww.write(data[split_i..]); + try std.testing.expectEqual(expected.len, ww.end_index); + switch (kind) { + .bin => try std.testing.expectEqualSlices(u8, expected, ww.buffer), + .str => try std.testing.expectEqualStrings(expected, ww.buffer), + } + } +} + +test WindowedWriter { + try testWindowedWriter(.str, .{ .start = 0, .size = 3 }, "foo\n", "foo"); + try testWindowedWriter(.str, .{ .start = 1, .size = 2 }, "foo\n", "oo"); + try testWindowedWriter(.str, .{ .start = 1, .size = 1 }, "foo\n", "o"); + try testWindowedWriter(.str, .{ .start = 2, .size = 1 }, "foo\n", "o"); + + try testWindowedWriter(.str, .{ .start = 1, .size = 3 }, "foo\n", "oo\n"); + try testWindowedWriter(.str, .{ .start = 2, .size = 2 }, "foo\n", "o\n"); + + try testWindowedWriter(.str, .{ .start = 0, .size = 1 }, "foo\n", "f"); + try testWindowedWriter(.str, .{ .start = 1, .size = 1 }, "foo\n", "o"); + try testWindowedWriter(.str, .{ .start = 2, .size = 1 }, "foo\n", "o"); + try testWindowedWriter(.str, .{ .start = 3, .size = 1 }, "foo\n", "\n"); +}