From 645fe1a8070fd033f1a957bef69a23f69e9f6d54 Mon Sep 17 00:00:00 2001 From: x19 <100000306+0xNineteen@users.noreply.github.com> Date: Mon, 3 Feb 2025 15:12:57 -0500 Subject: [PATCH] fix(adb/threadpool): remove waitgroups and fix segfault (#522) * fix(adb/threadpool): remove waitgroups and fix segfault * address comments * address race in geyser (tsan failing on ci) --- src/accountsdb/db.zig | 112 ++++++++-------------- src/cmd.zig | 1 - src/geyser/benchmark.zig | 10 +- src/geyser/core.zig | 7 +- src/geyser/main.zig | 15 ++- src/prometheus/registry.zig | 7 +- src/shred_network/repair_service.zig | 7 +- src/utils/thread.zig | 138 ++++++++++++++++++--------- 8 files changed, 162 insertions(+), 135 deletions(-) diff --git a/src/accountsdb/db.zig b/src/accountsdb/db.zig index ca87086f9..bce1ea200 100644 --- a/src/accountsdb/db.zig +++ b/src/accountsdb/db.zig @@ -294,13 +294,14 @@ pub const AccountsDB = struct { try self.fastload(fastload_dir, collapsed_manifest.accounts_db_fields); self.logger.info().logf("fastload: total time: {s}", .{timer.read()}); } else { - const load_duration = try self.loadFromSnapshot( + var load_timer = try sig.time.Timer.start(); + try self.loadFromSnapshot( collapsed_manifest.accounts_db_fields, n_threads, allocator, accounts_per_file_estimate, ); - self.logger.info().logf("loadFromSnapshot: total time: {s}", .{load_duration}); + self.logger.info().logf("loadFromSnapshot: total time: {s}", .{load_timer.read()}); } // no need to re-save if we just loaded from a fastload @@ -404,7 +405,7 @@ pub const AccountsDB = struct { /// needs to be a thread-safe allocator per_thread_allocator: std.mem.Allocator, accounts_per_file_estimate: u64, - ) !sig.time.Duration { + ) !void { self.logger.info().log("running loadFromSnapshot..."); // used to read account files @@ -429,27 +430,6 @@ pub const AccountsDB = struct { bhs.accumulate(snapshot_manifest.bank_hash_info.stats); } - var timer = try sig.time.Timer.start(); - // short path - if (n_threads == 1) { - try self.loadAndVerifyAccountsFiles( - accounts_dir, - accounts_per_file_estimate, - snapshot_manifest.file_map, - 0, - n_account_files, - true, - ); - - // if geyser, send end of data signal - if (self.geyser_writer) |geyser_writer| { - const end_of_snapshot: sig.geyser.core.VersionedAccountPayload = .EndOfSnapshotLoading; - try geyser_writer.writePayloadToPipe(end_of_snapshot); - } - - return timer.read(); - } - // setup the parallel indexing const loading_threads = try self.allocator.alloc(AccountsDB, n_parse_threads); defer self.allocator.free(loading_threads); @@ -457,12 +437,11 @@ pub const AccountsDB = struct { try initLoadingThreads(per_thread_allocator, loading_threads, self); defer deinitLoadingThreads(per_thread_allocator, loading_threads); - self.logger.info().logf("[{d} threads]: reading and indexing accounts...", .{n_parse_threads}); - { - var wg: std.Thread.WaitGroup = .{}; - defer wg.wait(); - try spawnThreadTasks(loadAndVerifyAccountsFilesMultiThread, .{ - .wg = &wg, + self.logger.info().logf("[{d} threads]: running loadAndVerifyAccountsFiles...", .{n_parse_threads}); + try spawnThreadTasks( + self.allocator, + loadAndVerifyAccountsFilesMultiThread, + .{ .data_len = n_account_files, .max_threads = n_parse_threads, .params = .{ @@ -471,8 +450,8 @@ pub const AccountsDB = struct { snapshot_manifest.file_map, accounts_per_file_estimate, }, - }); - } + }, + ); // if geyser, send end of data signal if (self.geyser_writer) |geyser_writer| { @@ -483,9 +462,6 @@ pub const AccountsDB = struct { var merge_timer = try 
sig.time.Timer.start(); try self.mergeMultipleDBs(loading_threads, n_combine_threads); self.logger.debug().logf("mergeMultipleDBs: total time: {}", .{merge_timer.read()}); - - self.logger.debug().logf("loadFromSnapshot: total time: {s}", .{timer.read()}); - return timer.read(); } /// Initializes a slice of children `AccountsDB`s, used to divide the work of loading from a snapshot. @@ -777,10 +753,7 @@ pub const AccountsDB = struct { ) !void { self.logger.info().logf("[{d} threads]: running mergeMultipleDBs...", .{n_threads}); - var merge_indexes_wg: std.Thread.WaitGroup = .{}; - defer merge_indexes_wg.wait(); - try spawnThreadTasks(mergeThreadIndexesMultiThread, .{ - .wg = &merge_indexes_wg, + try spawnThreadTasks(self.allocator, mergeThreadIndexesMultiThread, .{ .data_len = self.account_index.pubkey_ref_map.numberOfShards(), .max_threads = n_threads, .params = .{ @@ -954,35 +927,20 @@ pub const AccountsDB = struct { // split processing the bins over muliple threads self.logger.info().logf( - "collecting hashes from accounts using {} threads...", + "[{} threads] collecting hashes from accounts", .{n_threads}, ); - if (n_threads == 1) { - try getHashesFromIndex( + try spawnThreadTasks(self.allocator, getHashesFromIndexMultiThread, .{ + .data_len = self.account_index.pubkey_ref_map.numberOfShards(), + .max_threads = n_threads, + .params = .{ self, config, - self.account_index.pubkey_ref_map.shards, self.allocator, - &hashes[0], - &lamports[0], - true, - ); - } else { - var wg: std.Thread.WaitGroup = .{}; - defer wg.wait(); - try spawnThreadTasks(getHashesFromIndexMultiThread, .{ - .wg = &wg, - .data_len = self.account_index.pubkey_ref_map.numberOfShards(), - .max_threads = n_threads, - .params = .{ - self, - config, - self.allocator, - hashes, - lamports, - }, - }); - } + hashes, + lamports, + }, + }); self.logger.debug().logf("collecting hashes from accounts took: {s}", .{timer.read()}); timer.reset(); @@ -3215,7 +3173,7 @@ pub fn getAccountPerFileEstimateFromCluster( cluster: sig.core.Cluster, ) error{NotImplementedYet}!u64 { return switch (cluster) { - .testnet => 1_000, + .testnet => 500, else => error.NotImplementedYet, }; } @@ -3267,7 +3225,7 @@ fn testWriteSnapshotFull( var snap_fields = try SnapshotManifest.decodeFromBincode(allocator, manifest_file.reader()); defer snap_fields.deinit(allocator); - _ = try accounts_db.loadFromSnapshot(snap_fields.accounts_db_fields, 1, allocator, 1_500); + try accounts_db.loadFromSnapshot(snap_fields.accounts_db_fields, 1, allocator, 500); const snapshot_gen_info = try accounts_db.generateFullSnapshot(.{ .target_slot = slot, @@ -3306,7 +3264,7 @@ fn testWriteSnapshotIncremental( var snap_fields = try SnapshotManifest.decodeFromBincode(allocator, manifest_file.reader()); defer snap_fields.deinit(allocator); - _ = try accounts_db.loadFromSnapshot(snap_fields.accounts_db_fields, 1, allocator, 1_500); + try accounts_db.loadFromSnapshot(snap_fields.accounts_db_fields, 1, allocator, 500); const snapshot_gen_info = try accounts_db.generateIncrementalSnapshot(.{ .target_slot = slot, @@ -3474,7 +3432,7 @@ fn loadTestAccountsDB( }); errdefer accounts_db.deinit(); - _ = try accounts_db.loadFromSnapshot( + try accounts_db.loadFromSnapshot( manifest.accounts_db_fields, n_threads, allocator, @@ -3516,12 +3474,18 @@ test "geyser stream on load" { // start the geyser writer try geyser_writer.spawnIOLoop(); - const reader_handle = try std.Thread.spawn(.{}, sig.geyser.core.streamReader, .{ + var reader = try sig.geyser.GeyserReader.init( allocator, + geyser_pipe_path, 
+ &geyser_exit, + .{}, + ); + defer reader.deinit(); + + const reader_handle = try std.Thread.spawn(.{}, sig.geyser.core.streamReader, .{ + &reader, .noop, &geyser_exit, - geyser_pipe_path, - null, null, }); defer reader_handle.join(); @@ -3543,11 +3507,11 @@ test "geyser stream on load" { }); defer accounts_db.deinit(); - _ = try accounts_db.loadFromSnapshot( + try accounts_db.loadFromSnapshot( snapshot.accounts_db_fields, 1, allocator, - 1_500, + 500, ); } @@ -4487,12 +4451,14 @@ pub const BenchmarkAccountsDBSnapshotLoad = struct { }); defer accounts_db.deinit(); - const loading_duration = try accounts_db.loadFromSnapshot( + var load_timer = try sig.time.Timer.start(); + try accounts_db.loadFromSnapshot( collapsed_manifest.accounts_db_fields, bench_args.n_threads, allocator, try getAccountPerFileEstimateFromCluster(bench_args.cluster), ); + const loading_duration = load_timer.read(); const fastload_save_duration = blk: { var timer = try sig.time.Timer.start(); diff --git a/src/cmd.zig b/src/cmd.zig index 437cdca6d..f3cbd929f 100644 --- a/src/cmd.zig +++ b/src/cmd.zig @@ -1501,7 +1501,6 @@ fn loadSnapshot( break :blk cli_n_threads_snapshot_load; } }; - logger.info().logf("n_threads_snapshot_load: {d}", .{n_threads_snapshot_load}); var accounts_db = try AccountsDB.init(.{ .allocator = allocator, diff --git a/src/geyser/benchmark.zig b/src/geyser/benchmark.zig index 3c0830149..867b8ff4c 100644 --- a/src/geyser/benchmark.zig +++ b/src/geyser/benchmark.zig @@ -61,10 +61,18 @@ pub fn runBenchmark(logger: sig.trace.Logger) !void { exit.* = std.atomic.Value(bool).init(false); + var reader = try sig.geyser.GeyserReader.init( + allocator, + PIPE_PATH, + exit, + .{}, + ); + defer reader.deinit(); + const reader_handle = try std.Thread.spawn( .{}, geyser.core.streamReader, - .{ allocator, logger, exit, PIPE_PATH, MEASURE_RATE, null }, + .{ &reader, logger, exit, MEASURE_RATE }, ); const writer_handle = try std.Thread.spawn(.{}, streamWriter, .{ allocator, exit }); diff --git a/src/geyser/core.zig b/src/geyser/core.zig index 5b2c15bc8..4b31aa68d 100644 --- a/src/geyser/core.zig +++ b/src/geyser/core.zig @@ -479,16 +479,11 @@ pub fn openPipe(pipe_path: []const u8) !std.fs.File { } pub fn streamReader( - allocator: std.mem.Allocator, + reader: *GeyserReader, logger: sig.trace.Logger, exit: *std.atomic.Value(bool), - pipe_path: []const u8, measure_rate: ?sig.time.Duration, - allocator_config: ?GeyserReader.AllocatorConfig, ) !void { - var reader = try sig.geyser.GeyserReader.init(allocator, pipe_path, exit, allocator_config orelse .{}); - defer reader.deinit(); - var bytes_read: usize = 0; var timer = try sig.time.Timer.start(); diff --git a/src/geyser/main.zig b/src/geyser/main.zig index 19a1f2822..1aea35c4a 100644 --- a/src/geyser/main.zig +++ b/src/geyser/main.zig @@ -356,15 +356,22 @@ pub fn benchmark() !void { logger.info().logf("using pipe path: {s}", .{pipe_path}); var exit = std.atomic.Value(bool).init(false); - try sig.geyser.core.streamReader( + + var reader = try sig.geyser.GeyserReader.init( allocator, - logger, - &exit, pipe_path, - sig.time.Duration.fromSecs(config.measure_rate_secs), + &exit, .{ .io_buf_len = 1 << 30, .bincode_buf_len = 1 << 30, }, ); + defer reader.deinit(); + + try sig.geyser.core.streamReader( + &reader, + logger, + &exit, + sig.time.Duration.fromSecs(config.measure_rate_secs), + ); } diff --git a/src/prometheus/registry.zig b/src/prometheus/registry.zig index 27afc2e6f..0422db708 100644 --- a/src/prometheus/registry.zig +++ b/src/prometheus/registry.zig @@ -238,13 
+238,12 @@ pub fn Registry(comptime options: RegistryOptions) type { if (self.nbMetrics() >= options.max_metrics) return error.TooManyMetrics; if (name.len > options.max_name_len) return error.NameTooLong; - var allocator = self.arena_state.allocator(); - - const duped_name = try allocator.dupe(u8, name); - self.mutex.lock(); defer self.mutex.unlock(); + const allocator = self.arena_state.allocator(); + const duped_name = try allocator.dupe(u8, name); + const gop = try self.metrics.getOrPut(allocator, duped_name); if (!gop.found_existing) { var real_metric = try allocator.create(MetricType); diff --git a/src/shred_network/repair_service.zig b/src/shred_network/repair_service.zig index d491ee420..cf8acb508 100644 --- a/src/shred_network/repair_service.zig +++ b/src/shred_network/repair_service.zig @@ -102,6 +102,7 @@ pub const RepairService = struct { peer_provider: RepairPeerProvider, shred_tracker: *BasicShredTracker, ) !Self { + const n_threads = maxRequesterThreads(); return RepairService{ .allocator = allocator, .requester = requester, @@ -110,7 +111,7 @@ pub const RepairService = struct { .logger = logger.withScope(@typeName(Self)), .exit = exit, .report = MultiSlotReport.init(allocator), - .thread_pool = RequestBatchThreadPool.init(allocator, maxRequesterThreads()), + .thread_pool = try RequestBatchThreadPool.init(allocator, n_threads, n_threads), .metrics = try registry.initStruct(Metrics), .prng = std.Random.DefaultPrng.init(0), }; @@ -170,12 +171,12 @@ pub const RepairService = struct { for (0..num_threads) |i| { const start = (addressed_requests.items.len * i) / num_threads; const end = (addressed_requests.items.len * (i + 1)) / num_threads; - try self.thread_pool.schedule(.{ + self.thread_pool.schedule(.{ .requester = &self.requester, .requests = addressed_requests.items[start..end], }); - try self.thread_pool.joinFallible(); } + try self.thread_pool.joinFallible(); } return addressed_requests.items.len; diff --git a/src/utils/thread.zig b/src/utils/thread.zig index cbafc2096..27359bea2 100644 --- a/src/utils/thread.zig +++ b/src/utils/thread.zig @@ -1,6 +1,5 @@ const std = @import("std"); -const Allocator = std.mem.Allocator; const Condition = std.Thread.Condition; const Mutex = std.Thread.Mutex; @@ -14,25 +13,20 @@ pub const TaskParams = struct { }; fn chunkSizeAndThreadCount(data_len: usize, max_n_threads: usize) struct { usize, usize } { - var chunk_size = data_len / max_n_threads; var n_threads = max_n_threads; + var chunk_size = data_len / n_threads; if (chunk_size == 0) { + // default to one thread for all the data n_threads = 1; chunk_size = data_len; } return .{ chunk_size, n_threads }; } -pub fn SpawnThreadTasksConfig(comptime TaskFn: type) type { +pub fn SpawnThreadTasksParams(comptime TaskFn: type) type { return struct { - wg: *std.Thread.WaitGroup, data_len: usize, max_threads: usize, - /// If non-null, set to the coverage over the data which was achieved. - /// On a successful call, this will be equal to `data_len`. - /// On a failed call, this will be less than `data_len`, - /// representing the length of the data which was successfully - coverage: ?*usize = null, params: Params, pub const Params = std.meta.ArgsTuple(@Type(.{ .Fn = blk: { @@ -43,41 +37,46 @@ pub fn SpawnThreadTasksConfig(comptime TaskFn: type) type { }; } +/// this function spawns a number of threads to run the same task function. 
pub fn spawnThreadTasks( + allocator: std.mem.Allocator, comptime taskFn: anytype, - config: SpawnThreadTasksConfig(@TypeOf(taskFn)), -) std.Thread.SpawnError!void { - const Config = SpawnThreadTasksConfig(@TypeOf(taskFn)); + config: SpawnThreadTasksParams(@TypeOf(taskFn)), +) !void { const chunk_size, const n_threads = chunkSizeAndThreadCount(config.data_len, config.max_threads); - if (config.coverage) |coverage| coverage.* = 0; - const S = struct { - fn taskFnWg(wg: *std.Thread.WaitGroup, fn_params: Config.Params, task_params: TaskParams) @typeInfo(@TypeOf(taskFn)).Fn.return_type.? { - defer wg.finish(); - return @call(.auto, taskFn, fn_params ++ .{task_params}); + task_params: TaskParams, + fcn_params: @TypeOf(config).Params, + + fn run(self: *const @This()) @typeInfo(@TypeOf(taskFn)).Fn.return_type.? { + return @call(.auto, taskFn, self.fcn_params ++ .{self.task_params}); } }; + var thread_pool = try HomogeneousThreadPool(S).init( + allocator, + @intCast(n_threads), + n_threads, + ); + defer thread_pool.deinit(); + var start_index: usize = 0; for (0..n_threads) |thread_id| { const end_index = if (thread_id == n_threads - 1) config.data_len else (start_index + chunk_size); - const task_params: TaskParams = .{ - .start_index = start_index, - .end_index = end_index, - .thread_id = thread_id, - }; + thread_pool.schedule(.{ + .task_params = .{ + .start_index = start_index, + .end_index = end_index, + .thread_id = thread_id, + }, + .fcn_params = config.params, + }); - config.wg.start(); - const handle = std.Thread.spawn(.{}, S.taskFnWg, .{ config.wg, config.params, task_params }) catch |err| { - if (config.coverage) |coverage| coverage.* = start_index; - return err; - }; - handle.detach(); start_index = end_index; } - if (config.coverage) |coverage| coverage.* = config.data_len; + try thread_pool.joinFallible(); } pub fn ThreadPoolTask(comptime Entry: type) type { @@ -201,12 +200,16 @@ pub fn HomogeneousThreadPool(comptime TaskType: type) type { const Self = @This(); - pub fn init(allocator: std.mem.Allocator, num_threads: u32) Self { + pub fn init( + allocator: std.mem.Allocator, + num_threads: u32, + num_tasks: u64, + ) !Self { return .{ .allocator = allocator, .pool = ThreadPool.init(.{ .max_threads = num_threads }), - .tasks = std.ArrayList(TaskAdapter).init(allocator), - .results = std.ArrayList(TaskResult).init(allocator), + .tasks = try std.ArrayList(TaskAdapter).initCapacity(allocator, num_tasks), + .results = try std.ArrayList(TaskResult).initCapacity(allocator, num_tasks), }; } @@ -217,32 +220,77 @@ pub fn HomogeneousThreadPool(comptime TaskType: type) type { self.pool.deinit(); } - pub fn schedule(self: *Self, typed_task: TaskType) Allocator.Error!void { - const result = try self.results.addOne(); - var task = try self.tasks.addOne(); + pub fn schedule(self: *Self, typed_task: TaskType) void { + // NOTE: this breaks other pre-scheduled tasks on re-allocs so we dont + // allow re-allocations + const result = self.results.addOneAssumeCapacity(); + var task = self.tasks.addOneAssumeCapacity(); task.* = .{ .typed_task = typed_task, .result = result }; self.pool.schedule(Batch.from(&task.pool_task)); } /// blocks until all tasks are complete /// returns a list of any results for tasks that did not have a pointer provided - pub fn join(self: *Self) std.ArrayList(TaskResult) { + /// NOTE: if this fails then the result field is left in a bad state in which case the + /// thread pool should be discarded/reset + pub fn join(self: *Self) std.mem.Allocator.Error!std.ArrayList(TaskResult) { 
for (self.tasks.items) |*task| task.join(); const results = self.results; - self.results = std.ArrayList(TaskResult).init(self.allocator); + self.results = try std.ArrayList(TaskResult).initCapacity(self.allocator, self.tasks.capacity); self.tasks.clearRetainingCapacity(); return results; } /// Like join, but it returns an error if any tasks failed, and otherwise discards task output. + /// NOTE: this will return the first error encountered which may be inconsistent between runs. pub fn joinFallible(self: *Self) !void { - const results = self.join(); + const results = try self.join(); for (results.items) |result| try result; results.deinit(); } }; } +fn testSpawnThreadTasks( + values: []const u64, + sums: []u64, + task: TaskParams, +) !void { + std.debug.assert(@import("builtin").is_test); + var sum: u64 = 0; + for (task.start_index..task.end_index) |i| { + sum += values[i]; + } + sums[task.thread_id] = sum; +} + +test spawnThreadTasks { + const n_threads = 4; + const allocator = std.testing.allocator; + + const sums = try allocator.alloc(u64, n_threads); + defer allocator.free(sums); + + try spawnThreadTasks( + std.testing.allocator, + testSpawnThreadTasks, + .{ + .data_len = 10, + .max_threads = n_threads, + .params = .{ + &[_]u64{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, + sums, + }, + }, + ); + + var total_sum: u64 = 0; + for (sums) |sum| { + total_sum += sum; + } + try std.testing.expectEqual(55, total_sum); +} + test "typed thread pool" { const AdditionTask = struct { a: u64, @@ -252,13 +300,17 @@ test "typed thread pool" { } }; - var pool = HomogeneousThreadPool(AdditionTask).init(std.testing.allocator, 2); + var pool = try HomogeneousThreadPool(AdditionTask).init( + std.testing.allocator, + 2, + 3, + ); defer pool.deinit(); - try pool.schedule(.{ .a = 1, .b = 1 }); - try pool.schedule(.{ .a = 1, .b = 2 }); - try pool.schedule(.{ .a = 1, .b = 4 }); + pool.schedule(.{ .a = 1, .b = 1 }); + pool.schedule(.{ .a = 1, .b = 2 }); + pool.schedule(.{ .a = 1, .b = 4 }); - const results = pool.join(); + const results = try pool.join(); defer results.deinit(); try std.testing.expect(3 == results.items.len);
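
The usage pattern this patch converges on (visible in repair_service.zig, and in the rewritten spawnThreadTasks) is: size the HomogeneousThreadPool for the exact number of tasks up front, schedule everything with the now-infallible schedule, and join exactly once after the loop. Below is a minimal sketch of that pattern, assuming the API as introduced in this diff (init(allocator, num_threads, num_tasks), schedule without error, joinFallible); the SquareTask type, the squareAll helper, and the sig.utils.thread import path are illustrative assumptions, not part of the patch.

const std = @import("std");
const sig = @import("sig");

/// Hypothetical task type for illustration; any struct with a `run` method works.
const SquareTask = struct {
    input: u64,
    out: *u64,

    pub fn run(self: *const SquareTask) !void {
        self.out.* = self.input * self.input;
    }
};

fn squareAll(allocator: std.mem.Allocator, inputs: []const u64, outputs: []u64) !void {
    std.debug.assert(inputs.len == outputs.len);

    // Capacity for the task/result lists is fixed up front: `schedule` uses
    // addOneAssumeCapacity, so the pointers handed to the underlying ThreadPool
    // are never invalidated by an ArrayList realloc (the segfault this patch fixes).
    var pool = try sig.utils.thread.HomogeneousThreadPool(SquareTask).init(
        allocator,
        4, // num_threads
        inputs.len, // num_tasks: upper bound on how many tasks will be scheduled
    );
    defer pool.deinit();

    // schedule is now infallible, so there is no error path mid-loop.
    for (inputs, outputs) |input, *out| {
        pool.schedule(.{ .input = input, .out = out });
    }

    // Join once after everything is scheduled (not once per task), propagating
    // the first task error if any `run` failed.
    try pool.joinFallible();
}

Scheduling all tasks before a single joinFallible is the same shape as the corrected loop in RepairService.sendNecessaryRepairs above, where the join was previously (and incorrectly) inside the per-chunk loop.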