From 1dd5c080708b4a1698066af9ead2ce64038dbda8 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Mon, 13 Jan 2025 23:05:49 +0800 Subject: [PATCH 1/6] More overflow-safe swiss table. --- cpp/src/arrow/acero/swiss_join.cc | 52 +++-- cpp/src/arrow/acero/swiss_join_internal.h | 4 +- cpp/src/arrow/compute/key_map_internal.cc | 204 ++++++++---------- cpp/src/arrow/compute/key_map_internal.h | 90 +++++--- .../arrow/compute/key_map_internal_avx2.cc | 62 ++++-- 5 files changed, 209 insertions(+), 203 deletions(-) diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index 85e14ac469ce7..b1ba77216e1a5 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -643,37 +643,36 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc // int source_group_id_bits = SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks()); - uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits); - int64_t source_block_bytes = source_group_id_bits + 8; + int source_block_bytes = + SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits); ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0); // Compute index of the last block in target that corresponds to the given // partition. // ARROW_DCHECK(num_partition_bits <= target->log_blocks()); - int64_t target_max_block_id = + uint32_t target_max_block_id = ((partition_id + 1) << (target->log_blocks() - num_partition_bits)) - 1; overflow_group_ids->clear(); overflow_hashes->clear(); // For each source block... - int64_t source_blocks = 1LL << source->log_blocks(); - for (int64_t block_id = 0; block_id < source_blocks; ++block_id) { - uint8_t* block_bytes = source->blocks() + block_id * source_block_bytes; + uint32_t source_blocks = 1 << source->log_blocks(); + for (uint32_t block_id = 0; block_id < source_blocks; ++block_id) { + const uint8_t* block_bytes = source->block_data(block_id, source_block_bytes); uint64_t block = *reinterpret_cast(block_bytes); // For each non-empty source slot... constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; - constexpr int kSlotsPerBlock = 8; - int num_full_slots = - kSlotsPerBlock - static_cast(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); + int num_full_slots = SwissTable::kSlotsPerBlock - + static_cast(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) { // Read group id and hash for this slot. // - uint64_t group_id = - source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask); - int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id; + uint32_t group_id = + source->extract_group_id(block_bytes, local_slot_id, source_group_id_bits); + uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id); uint32_t hash = source->hashes()[global_slot_id]; // Insert partition id into the highest bits of hash, shifting the // remaining hash bits right. 
@@ -696,17 +695,18 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc } } -inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id, - uint32_t hash, int64_t max_block_id) { +inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint32_t group_id, + uint32_t hash, uint32_t max_block_id) { // Load the first block to visit for this hash // - int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks()); - int64_t block_id_mask = ((1LL << target->log_blocks()) - 1); + uint32_t block_id = SwissTable::block_id_from_hash(hash, target->log_blocks()); + uint32_t block_id_mask = (1 << target->log_blocks()) - 1; int num_group_id_bits = SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks()); - int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t); + int num_block_bytes = + SwissTable::num_block_bytes_from_num_groupid_bits(num_group_id_bits); ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0); - uint8_t* block_bytes = target->blocks() + block_id * num_block_bytes; + const uint8_t* block_bytes = target->block_data(block_id, num_block_bytes); uint64_t block = *reinterpret_cast(block_bytes); // Search for the first block with empty slots. @@ -715,25 +715,23 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id) { block_id = (block_id + 1) & block_id_mask; - block_bytes = target->blocks() + block_id * num_block_bytes; + block_bytes = target->block_data(block_id, num_block_bytes); block = *reinterpret_cast(block_bytes); } if ((block & kHighBitOfEachByte) == 0) { return false; } - constexpr int kSlotsPerBlock = 8; - int local_slot_id = - kSlotsPerBlock - static_cast(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); - int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id; - target->insert_into_empty_slot(static_cast(global_slot_id), hash, - static_cast(group_id)); + int local_slot_id = SwissTable::kSlotsPerBlock - + static_cast(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); + uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id); + target->insert_into_empty_slot(global_slot_id, hash, group_id); return true; } void SwissTableMerge::InsertNewGroups(SwissTable* target, const std::vector& group_ids, const std::vector& hashes) { - int64_t num_blocks = 1LL << target->log_blocks(); + uint32_t num_blocks = 1 << target->log_blocks(); for (size_t i = 0; i < group_ids.size(); ++i) { std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks); } @@ -1191,7 +1189,7 @@ Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id, // We want each partition to correspond to a range of block indices, // so we also partition on the highest bits of the hash. // - return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1; + return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_); }, [&locals](int64_t i, int pos) { locals.batch_prtn_row_ids[pos] = static_cast(i); diff --git a/cpp/src/arrow/acero/swiss_join_internal.h b/cpp/src/arrow/acero/swiss_join_internal.h index 85f443b0323c7..d0d97aa1cc0fe 100644 --- a/cpp/src/arrow/acero/swiss_join_internal.h +++ b/cpp/src/arrow/acero/swiss_join_internal.h @@ -380,8 +380,8 @@ class SwissTableMerge { // Max block id value greater or equal to the number of blocks guarantees that // the search will not be stopped. 
// - static inline bool InsertNewGroup(SwissTable* target, uint64_t group_id, uint32_t hash, - int64_t max_block_id); + static inline bool InsertNewGroup(SwissTable* target, uint32_t group_id, uint32_t hash, + uint32_t max_block_id); }; struct SwissTableWithKeys { diff --git a/cpp/src/arrow/compute/key_map_internal.cc b/cpp/src/arrow/compute/key_map_internal.cc index ad264533bff30..860cebfd33ba6 100644 --- a/cpp/src/arrow/compute/key_map_internal.cc +++ b/cpp/src/arrow/compute/key_map_internal.cc @@ -94,27 +94,32 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, *out_slot = static_cast(CountLeadingZeros(matches | block_high_bits) >> 3); } -template +template void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selection, const uint32_t* hashes, const uint8_t* local_slots, - uint32_t* out_group_ids, int element_offset, - int element_multiplier) const { - const T* elements = reinterpret_cast(blocks_->data()) + element_offset; + uint32_t* out_group_ids) const { if (log_blocks_ == 0) { - ARROW_DCHECK(sizeof(T) == sizeof(uint8_t)); for (int i = 0; i < num_keys; ++i) { uint32_t id = use_selection ? selection[i] : i; - uint32_t group_id = blocks()[8 + local_slots[id]]; + uint32_t group_id = + block_data(/*block_id=*/0, + /*num_block_bytes=*/0)[bytes_status_in_block_ + local_slots[id]]; out_group_ids[id] = group_id; } } else { + int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + int num_groupid_bytes = num_groupid_bits / 8; + uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_groupid_bits); + int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); + for (int i = 0; i < num_keys; ++i) { uint32_t id = use_selection ? selection[i] : i; uint32_t hash = hashes[id]; - int64_t pos = - (hash >> (bits_hash_ - log_blocks_)) * element_multiplier + local_slots[id]; - uint32_t group_id = static_cast(elements[pos]); - ARROW_DCHECK(group_id < num_inserted_ || num_inserted_ == 0); + uint32_t block_id = block_id_from_hash(hash, log_blocks_); + uint32_t group_id = *reinterpret_cast( + block_data(block_id, num_block_bytes) + local_slots[id] * num_groupid_bytes + + bytes_status_in_block_); + group_id &= group_id_mask; out_group_ids[id] = group_id; } } @@ -123,59 +128,22 @@ void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selec void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_selection, const uint32_t* hashes, const uint8_t* local_slots, uint32_t* out_group_ids) const { - // Group id values for all 8 slots in the block are bit-packed and follow the status - // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In - // that case we can extract group id using aligned 64-bit word access. - int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_); - ARROW_DCHECK(num_group_id_bits == 8 || num_group_id_bits == 16 || - num_group_id_bits == 32); - int num_processed = 0; - // Optimistically use simplified lookup involving only a start block to find // a single group id candidate for every input. 
#if defined(ARROW_HAVE_RUNTIME_AVX2) && defined(ARROW_HAVE_RUNTIME_BMI2) - int num_group_id_bytes = num_group_id_bits / 8; if ((hardware_flags_ & CpuInfo::AVX2) && CpuInfo::GetInstance()->HasEfficientBmi2() && !optional_selection) { - num_processed = extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids, - sizeof(uint64_t), 8 + 8 * num_group_id_bytes, - num_group_id_bytes); + num_processed = extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids); } #endif - switch (num_group_id_bits) { - case 8: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 8, 16); - } else { - extract_group_ids_imp( - num_keys - num_processed, nullptr, hashes + num_processed, - local_slots + num_processed, out_group_ids + num_processed, 8, 16); - } - break; - case 16: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 4, 12); - } else { - extract_group_ids_imp( - num_keys - num_processed, nullptr, hashes + num_processed, - local_slots + num_processed, out_group_ids + num_processed, 4, 12); - } - break; - case 32: - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, - local_slots, out_group_ids, 2, 10); - } else { - extract_group_ids_imp( - num_keys - num_processed, nullptr, hashes + num_processed, - local_slots + num_processed, out_group_ids + num_processed, 2, 10); - } - break; - default: - ARROW_DCHECK(false); + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, local_slots, + out_group_ids); + } else { + extract_group_ids_imp(num_keys - num_processed, nullptr, + hashes + num_processed, local_slots + num_processed, + out_group_ids + num_processed); } } @@ -195,9 +163,9 @@ void SwissTable::init_slot_ids(const int num_keys, const uint16_t* selection, for (int i = 0; i < num_keys; ++i) { uint16_t id = selection[i]; uint32_t hash = hashes[id]; - uint32_t iblock = (hash >> (bits_hash_ - log_blocks_)); + uint32_t iblock = block_id_from_hash(hash, log_blocks_); uint32_t match = ::arrow::bit_util::GetBit(match_bitvector, id) ? 
1 : 0; - uint32_t slot_id = iblock * 8 + local_slots[id] + match; + uint32_t slot_id = global_slot_id(iblock, local_slots[id] + match); out_slot_ids[id] = slot_id; } } @@ -207,11 +175,11 @@ void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* id const uint32_t* hashes, uint32_t* slot_ids) const { int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); - uint32_t num_block_bytes = num_groupid_bits + 8; + int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); if (log_blocks_ == 0) { uint64_t block = *reinterpret_cast(blocks_->mutable_data()); - uint32_t empty_slot = - static_cast(8 - ARROW_POPCOUNT64(block & kHighBitOfEachByte)); + uint32_t empty_slot = static_cast( + kSlotsPerBlock - ARROW_POPCOUNT64(block & kHighBitOfEachByte)); for (uint32_t i = 0; i < num_ids; ++i) { int id = ids[i]; slot_ids[id] = empty_slot; @@ -220,19 +188,18 @@ void SwissTable::init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t* id for (uint32_t i = 0; i < num_ids; ++i) { int id = ids[i]; uint32_t hash = hashes[id]; - uint32_t iblock = hash >> (bits_hash_ - log_blocks_); + uint32_t iblock = block_id_from_hash(hash, log_blocks_); uint64_t block; for (;;) { - block = *reinterpret_cast(blocks_->mutable_data() + - num_block_bytes * iblock); + block = *reinterpret_cast(block_data(iblock, num_block_bytes)); block &= kHighBitOfEachByte; if (block) { break; } iblock = (iblock + 1) & ((1 << log_blocks_) - 1); } - uint32_t empty_slot = static_cast(8 - ARROW_POPCOUNT64(block)); - slot_ids[id] = iblock * 8 + empty_slot; + uint32_t empty_slot = static_cast(kSlotsPerBlock - ARROW_POPCOUNT64(block)); + slot_ids[id] = global_slot_id(iblock, empty_slot); } } } @@ -249,6 +216,7 @@ void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes, // Based on the size of the table, prepare bit number constants. 
uint32_t stamp_mask = (1 << bits_stamp_) - 1; int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); for (int i = 0; i < num_keys; ++i) { // Extract from hash: block index and stamp @@ -258,9 +226,7 @@ void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes, uint32_t stamp = iblock & stamp_mask; iblock >>= bits_shift_for_block_; - uint32_t num_block_bytes = num_groupid_bits + 8; - const uint8_t* blockbase = - blocks_->data() + static_cast(iblock) * num_block_bytes; + const uint8_t* blockbase = block_data(iblock, num_block_bytes); ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0); uint64_t block = *reinterpret_cast(blockbase); @@ -297,8 +263,8 @@ uint64_t SwissTable::num_groups_for_resize() const { } } -uint64_t SwissTable::wrap_global_slot_id(uint64_t global_slot_id) const { - uint64_t global_slot_id_mask = (1 << (log_blocks_ + 3)) - 1; +uint32_t SwissTable::wrap_global_slot_id(uint32_t global_slot_id) const { + uint32_t global_slot_id_mask = static_cast((1ULL << (log_blocks_ + 3)) - 1); return global_slot_id & global_slot_id_mask; } @@ -396,23 +362,22 @@ void SwissTable::run_comparisons(const int num_keys, bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id, uint32_t* out_slot_id, uint32_t* out_group_id) const { - const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + const int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); constexpr uint64_t stamp_mask = 0x7f; const int stamp = static_cast((hash >> bits_shift_for_block_and_stamp_) & stamp_mask); - uint64_t start_slot_id = wrap_global_slot_id(in_slot_id); + uint32_t start_slot_id = wrap_global_slot_id(in_slot_id); int match_found; int local_slot; - uint8_t* blockbase; + const uint8_t* blockbase; for (;;) { - const uint64_t num_block_bytes = (8 + num_groupid_bits); - blockbase = blocks_->mutable_data() + num_block_bytes * (start_slot_id >> 3); - uint64_t block = *reinterpret_cast(blockbase); + blockbase = block_data(start_slot_id >> 3, num_block_bytes); + uint64_t block = *reinterpret_cast(blockbase); - search_block(block, stamp, (start_slot_id & 7), &local_slot, &match_found); + search_block(block, stamp, start_slot_id & 7, &local_slot, &match_found); - start_slot_id = - wrap_global_slot_id((start_slot_id & ~7ULL) + local_slot + match_found); + start_slot_id = wrap_global_slot_id((start_slot_id & ~7U) + local_slot + match_found); // Match found can be 1 in two cases: // - match was found @@ -423,10 +388,8 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl } } - const uint64_t groupid_mask = (1ULL << num_groupid_bits) - 1; - *out_group_id = - static_cast(extract_group_id(blockbase, local_slot, groupid_mask)); - *out_slot_id = static_cast(start_slot_id); + *out_group_id = extract_group_id(blockbase, local_slot, num_groupid_bits); + *out_slot_id = start_slot_id; return match_found; } @@ -531,7 +494,7 @@ Status SwissTable::map_new_keys_helper( // ARROW_DCHECK(*inout_num_selected <= static_cast(1 << log_minibatch_)); - size_t num_bytes_for_bits = (*inout_num_selected + 7) / 8 + sizeof(uint64_t); + size_t num_bytes_for_bits = (*inout_num_selected + 7) / 8 + bytes_status_in_block_; auto match_bitvector_buf = util::TempVectorHolder( temp_stack, static_cast(num_bytes_for_bits)); uint8_t* match_bitvector = 
match_bitvector_buf.mutable_data(); @@ -645,7 +608,8 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* for (uint32_t i = 0; i < num_ids; ++i) { // First slot in the new starting block const int16_t id = ids[i]; - slot_ids[id] = (hashes[id] >> (bits_hash_ - log_blocks_)) * 8; + uint32_t block_id = block_id_from_hash(hashes[id], log_blocks_); + slot_ids[id] = global_slot_id(block_id, 0); } } } while (num_ids > 0); @@ -662,10 +626,11 @@ Status SwissTable::grow_double() { int bits_shift_for_block_and_stamp_after = ComputeBitsShiftForBlockAndStamp(log_blocks_after); int bits_shift_for_block_after = ComputeBitsShiftForBlock(log_blocks_after); - uint64_t block_size_before = (8 + num_group_id_bits_before); - uint64_t block_size_after = (8 + num_group_id_bits_after); - uint64_t block_size_total_after = (block_size_after << log_blocks_after) + padding_; - uint64_t hashes_size_total_after = + int block_size_before = num_block_bytes_from_num_groupid_bits(num_group_id_bits_before); + int block_size_after = num_block_bytes_from_num_groupid_bits(num_group_id_bits_after); + int64_t block_size_total_after = + num_bytes_total_blocks(block_size_after, log_blocks_after); + int64_t hashes_size_total_after = (bits_hash_ / 8 * (1 << (log_blocks_after + 3))) + padding_; constexpr uint32_t stamp_mask = (1 << bits_stamp_) - 1; @@ -682,42 +647,42 @@ Status SwissTable::grow_double() { // (block other than selected by hash bits corresponding to the entry). for (int i = 0; i < (1 << log_blocks_); ++i) { // How many full slots in this block - uint8_t* block_base = blocks_->mutable_data() + i * block_size_before; + const uint8_t* block_base = block_data(i, block_size_before); uint8_t* double_block_base_new = - blocks_new->mutable_data() + 2 * i * block_size_after; + mutable_block_data(blocks_new->mutable_data(), 2 * i, block_size_after); uint64_t block = *reinterpret_cast(block_base); - auto full_slots = - static_cast(CountLeadingZeros(block & kHighBitOfEachByte) >> 3); - int full_slots_new[2]; + uint32_t full_slots = CountLeadingZeros(block & kHighBitOfEachByte) >> 3; + uint32_t full_slots_new[2]; full_slots_new[0] = full_slots_new[1] = 0; util::SafeStore(double_block_base_new, kHighBitOfEachByte); util::SafeStore(double_block_base_new + block_size_after, kHighBitOfEachByte); - for (int j = 0; j < full_slots; ++j) { - uint64_t slot_id = i * 8 + j; + for (uint32_t j = 0; j < full_slots; ++j) { + uint64_t slot_id = global_slot_id(i, j); uint32_t hash = hashes()[slot_id]; - uint64_t block_id_new = hash >> (bits_hash_ - log_blocks_after); + uint32_t block_id_new = block_id_from_hash(hash, log_blocks_after); bool is_overflow_entry = ((block_id_new >> 1) != static_cast(i)); if (is_overflow_entry) { continue; } - int ihalf = block_id_new & 1; + uint32_t ihalf = block_id_new & 1; uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask; uint64_t group_id_bit_offs = j * num_group_id_bits_before; uint64_t group_id = - (util::SafeLoadAs(block_base + 8 + (group_id_bit_offs >> 3)) >> + (util::SafeLoadAs(block_base + bytes_status_in_block_ + + (group_id_bit_offs >> 3)) >> (group_id_bit_offs & 7)) & group_id_mask_before; - uint64_t slot_id_new = i * 16 + ihalf * 8 + full_slots_new[ihalf]; + uint64_t slot_id_new = global_slot_id(i * 2 + ihalf, full_slots_new[ihalf]); hashes_new[slot_id_new] = hash; uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after; block_base_new[7 - full_slots_new[ihalf]] = stamp_new; - int group_id_bit_offs_new = full_slots_new[ihalf] * 
num_group_id_bits_after; - uint64_t* ptr = - reinterpret_cast(block_base_new + 8 + (group_id_bit_offs_new >> 3)); + int64_t group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after; + uint64_t* ptr = reinterpret_cast( + block_base_new + bytes_status_in_block_ + (group_id_bit_offs_new >> 3)); util::SafeStore(ptr, util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7))); full_slots_new[ihalf]++; @@ -728,14 +693,14 @@ Status SwissTable::grow_double() { // Reinsert entries that were in an overflow block. for (int i = 0; i < (1 << log_blocks_); ++i) { // How many full slots in this block - uint8_t* block_base = blocks_->mutable_data() + i * block_size_before; + const uint8_t* block_base = block_data(i, block_size_before); uint64_t block = util::SafeLoadAs(block_base); - int full_slots = static_cast(CountLeadingZeros(block & kHighBitOfEachByte) >> 3); + uint32_t full_slots = CountLeadingZeros(block & kHighBitOfEachByte) >> 3; - for (int j = 0; j < full_slots; ++j) { - uint64_t slot_id = i * 8 + j; + for (uint32_t j = 0; j < full_slots; ++j) { + uint64_t slot_id = global_slot_id(i, j); uint32_t hash = hashes()[slot_id]; - uint64_t block_id_new = hash >> (bits_hash_ - log_blocks_after); + uint32_t block_id_new = block_id_from_hash(hash, log_blocks_after); bool is_overflow_entry = ((block_id_new >> 1) != static_cast(i)); if (!is_overflow_entry) { continue; @@ -743,17 +708,18 @@ Status SwissTable::grow_double() { uint64_t group_id_bit_offs = j * num_group_id_bits_before; uint64_t group_id = - (util::SafeLoadAs(block_base + 8 + (group_id_bit_offs >> 3)) >> + (util::SafeLoadAs(block_base + bytes_status_in_block_ + + (group_id_bit_offs >> 3)) >> (group_id_bit_offs & 7)) & group_id_mask_before; uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask; uint8_t* block_base_new = - blocks_new->mutable_data() + block_id_new * block_size_after; + mutable_block_data(blocks_new->mutable_data(), block_id_new, block_size_after); uint64_t block_new = util::SafeLoadAs(block_base_new); int full_slots_new = static_cast(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3); - while (full_slots_new == 8) { + while (full_slots_new == kSlotsPerBlock) { block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1); block_base_new = blocks_new->mutable_data() + block_id_new * block_size_after; block_new = util::SafeLoadAs(block_base_new); @@ -763,9 +729,9 @@ Status SwissTable::grow_double() { hashes_new[block_id_new * 8 + full_slots_new] = hash; block_base_new[7 - full_slots_new] = stamp_new; - int group_id_bit_offs_new = full_slots_new * num_group_id_bits_after; - uint64_t* ptr = - reinterpret_cast(block_base_new + 8 + (group_id_bit_offs_new >> 3)); + int64_t group_id_bit_offs_new = full_slots_new * num_group_id_bits_after; + uint64_t* ptr = reinterpret_cast( + block_base_new + bytes_status_in_block_ + (group_id_bit_offs_new >> 3)); util::SafeStore(ptr, util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7))); } @@ -792,17 +758,17 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); num_inserted_ = 0; - const uint64_t block_bytes = 8 + num_groupid_bits; - const uint64_t slot_bytes = (block_bytes << log_blocks_) + padding_; + const int block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); + const int64_t slot_bytes = num_bytes_total_blocks(block_bytes, log_blocks_); ARROW_ASSIGN_OR_RAISE(blocks_, AllocateBuffer(slot_bytes, pool_)); // Make sure 
group ids are initially set to zero for all slots. memset(blocks_->mutable_data(), 0, slot_bytes); // Initialize all status bytes to represent an empty slot. - uint8_t* blocks_ptr = blocks_->mutable_data(); - for (uint64_t i = 0; i < (static_cast(1) << log_blocks_); ++i) { - util::SafeStore(blocks_ptr + i * block_bytes, kHighBitOfEachByte); + for (int i = 0; i < 1 << log_blocks_; ++i) { + auto block = mutable_block_data(i, block_bytes); + util::SafeStore(block, kHighBitOfEachByte); } if (no_hash_array) { diff --git a/cpp/src/arrow/compute/key_map_internal.h b/cpp/src/arrow/compute/key_map_internal.h index 66a9957006dd7..cd3ceee325fc5 100644 --- a/cpp/src/arrow/compute/key_map_internal.h +++ b/cpp/src/arrow/compute/key_map_internal.h @@ -81,18 +81,29 @@ class ARROW_EXPORT SwissTable { void num_inserted(uint32_t i) { num_inserted_ = i; } - uint8_t* blocks() const { return blocks_->mutable_data(); } - uint32_t* hashes() const { return reinterpret_cast(hashes_->mutable_data()); } + inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id); + /// \brief Extract group id for a given slot in a given block. /// - inline uint64_t extract_group_id(const uint8_t* block_ptr, int slot, - uint64_t group_id_mask) const; + static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot, + int64_t num_group_id_bits) { + uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_group_id_bits); + uint32_t group_id = *reinterpret_cast( + block_ptr + bytes_status_in_block_ + local_slot * num_group_id_bits / 8); + return group_id & group_id_mask; + } - inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id); + static uint32_t block_id_from_hash(uint32_t hash, int log_blocks) { + return hash >> (bits_hash_ - log_blocks); + } + + static uint32_t global_slot_id(uint32_t block_id, uint32_t local_slot_id) { + return block_id * static_cast(kSlotsPerBlock) + local_slot_id; + } static int num_groupid_bits_from_log_blocks(int log_blocks) { int required_bits = log_blocks + 3; @@ -102,10 +113,38 @@ class ARROW_EXPORT SwissTable { : 64; } + static int num_block_bytes_from_num_groupid_bits(int num_groupid_bits) { + return num_groupid_bits + bytes_status_in_block_; + } + + static int64_t num_bytes_total_blocks(int num_block_bytes, int log_blocks) { + return (static_cast(num_block_bytes) << log_blocks) + padding_; + } + + const uint8_t* block_data(uint32_t block_id, int num_block_bytes) const { + return block_data(blocks_->data(), block_id, num_block_bytes); + } + + uint8_t* mutable_block_data(uint32_t block_id, int num_block_bytes) { + return mutable_block_data(blocks_->mutable_data(), block_id, num_block_bytes); + } + + static constexpr int kSlotsPerBlock = 8; + // Use 32-bit hash for now static constexpr int bits_hash_ = 32; private: + static const uint8_t* block_data(const uint8_t* blocks, uint32_t block_id, + int num_block_bytes) { + return blocks + static_cast(block_id) * num_block_bytes; + } + + static uint8_t* mutable_block_data(uint8_t* blocks, uint32_t block_id, + int num_block_bytes) { + return blocks + static_cast(block_id) * num_block_bytes; + } + // Lookup helpers /// \brief Scan bytes in block in reverse and stop as soon @@ -139,18 +178,17 @@ class ARROW_EXPORT SwissTable { const uint32_t* hashes, const uint8_t* local_slots, uint32_t* out_group_ids) const; - template + template void extract_group_ids_imp(const int num_keys, const uint16_t* selection, const uint32_t* hashes, const uint8_t* local_slots, - uint32_t* out_group_ids, int 
elements_offset, - int element_multiplier) const; + uint32_t* out_group_ids) const; inline uint64_t next_slot_to_visit(uint64_t block_index, int slot, int match_found) const; inline uint64_t num_groups_for_resize() const; - inline uint64_t wrap_global_slot_id(uint64_t global_slot_id) const; + inline uint32_t wrap_global_slot_id(uint32_t global_slot_id) const; void init_slot_ids(const int num_keys, const uint16_t* selection, const uint32_t* hashes, const uint8_t* local_slots, @@ -173,8 +211,7 @@ class ARROW_EXPORT SwissTable { uint8_t* out_match_bitvector, uint8_t* out_local_slots) const; int extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, - const uint8_t* local_slots, uint32_t* out_group_ids, - int byte_offset, int byte_multiplier, int byte_size) const; + const uint8_t* local_slots, uint32_t* out_group_ids) const; #endif void run_comparisons(const int num_keys, const uint16_t* optional_selection_ids, @@ -220,6 +257,12 @@ class ARROW_EXPORT SwissTable { return bits_stamp_; } + static uint32_t group_id_mask_from_num_groupid_bits(int64_t num_groupid_bits) { + return static_cast((1ULL << num_groupid_bits) - 1); + } + + static constexpr int bytes_status_in_block_ = 8; + // Number of hash bits stored in slots in a block. // The highest bits of hash determine block id. // The next set of highest bits is a "stamp" stored in a slot in a block. @@ -263,39 +306,22 @@ class ARROW_EXPORT SwissTable { MemoryPool* pool_; }; -uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, - uint64_t group_id_mask) const { - // Group id values for all 8 slots in the block are bit-packed and follow the status - // bytes. We assume here that the number of bits is rounded up to 8, 16, 32 or 64. In - // that case we can extract group id using aligned 64-bit word access. - int num_group_id_bits = static_cast(ARROW_POPCOUNT64(group_id_mask)); - assert(num_group_id_bits == 8 || num_group_id_bits == 16 || num_group_id_bits == 32 || - num_group_id_bits == 64); - - int bit_offset = slot * num_group_id_bits; - const uint64_t* group_id_bytes = - reinterpret_cast(block_ptr) + 1 + (bit_offset >> 6); - uint64_t group_id = (*group_id_bytes >> (bit_offset & 63)) & group_id_mask; - - return group_id; -} - void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id) { - const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); // We assume here that the number of bits is rounded up to 8, 16, 32 or 64. // In that case we can insert group id value using aligned 64-bit word access. 
assert(num_groupid_bits == 8 || num_groupid_bits == 16 || num_groupid_bits == 32 || num_groupid_bits == 64); - const uint64_t num_block_bytes = (8 + num_groupid_bits); + const int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); constexpr uint64_t stamp_mask = 0x7f; int start_slot = (slot_id & 7); int stamp = static_cast((hash >> bits_shift_for_block_and_stamp_) & stamp_mask); - uint64_t block_id = slot_id >> 3; - uint8_t* blockbase = blocks_->mutable_data() + num_block_bytes * block_id; + uint32_t block_id = slot_id >> 3; + uint8_t* blockbase = mutable_block_data(block_id, num_block_bytes); blockbase[7 - start_slot] = static_cast(stamp); int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); diff --git a/cpp/src/arrow/compute/key_map_internal_avx2.cc b/cpp/src/arrow/compute/key_map_internal_avx2.cc index be54f7de63973..353d5a59e67d0 100644 --- a/cpp/src/arrow/compute/key_map_internal_avx2.cc +++ b/cpp/src/arrow/compute/key_map_internal_avx2.cc @@ -35,6 +35,7 @@ int SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* h constexpr int unroll = 8; const int num_group_id_bits = num_groupid_bits_from_log_blocks(log_blocks_); + const int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_group_id_bits); const __m256i* vhash_ptr = reinterpret_cast(hashes); const __m256i vstamp_mask = _mm256_set1_epi32((1 << bits_stamp_) - 1); @@ -53,7 +54,7 @@ int SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* h // in order to process 64-bit blocks // __m256i vblock_offset = - _mm256_mullo_epi32(vblock_id, _mm256_set1_epi32(num_group_id_bits + 8)); + _mm256_mullo_epi32(vblock_id, _mm256_set1_epi32(num_block_bytes)); __m256i voffset_A = _mm256_and_si256(vblock_offset, _mm256_set1_epi64x(0xffffffff)); __m256i vstamp_A = _mm256_and_si256(vstamp, _mm256_set1_epi64x(0xffffffff)); __m256i voffset_B = _mm256_srli_epi64(vblock_offset, 32); @@ -230,9 +231,10 @@ int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* // Assemble the sequence of block bytes. uint64_t block_bytes[16]; const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + const int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); for (int i = 0; i < (1 << log_blocks_); ++i) { uint64_t in_blockbytes = - *reinterpret_cast(blocks_->data() + (8 + num_groupid_bits) * i); + *reinterpret_cast(block_data(i, num_block_bytes)); block_bytes[i] = in_blockbytes; } @@ -365,14 +367,9 @@ int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashes, const uint8_t* local_slots, - uint32_t* out_group_ids, int byte_offset, - int byte_multiplier, int byte_size) const { - ARROW_DCHECK(byte_size == 1 || byte_size == 2 || byte_size == 4); - uint32_t mask = byte_size == 1 ? 0xFF : byte_size == 2 ? 
0xFFFF : 0xFFFFFFFF; - auto elements = reinterpret_cast(blocks_->data() + byte_offset); + uint32_t* out_group_ids) const { constexpr int unroll = 8; if (log_blocks_ == 0) { - ARROW_DCHECK(byte_size == 1 && byte_offset == 8 && byte_multiplier == 16); __m256i block_group_ids = _mm256_set1_epi64x(reinterpret_cast(blocks_->data())[1]); for (int i = 0; i < num_keys / unroll; ++i) { @@ -385,33 +382,52 @@ int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashe _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); } } else { + int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + int num_groupid_bytes = num_groupid_bits / 8; + uint32_t mask = num_groupid_bytes == 1 ? 0xFF + : num_groupid_bytes == 2 ? 0xFFFF + : 0xFFFFFFFF; + int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); + const int* slots_base = + reinterpret_cast(blocks_->data() + bytes_status_in_block_); + for (int i = 0; i < num_keys / unroll; ++i) { __m256i hash = _mm256_loadu_si256(reinterpret_cast(hashes) + i); - // Extend hash and local_slot to 64-bit to compute 64-bit group id offsets to - // gather from. This is to prevent index overflow issues in GH-44513. - // NB: Use zero-extend conversion for unsigned hash. - __m256i hash_lo = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(hash)); - __m256i hash_hi = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(hash, 1)); + __m256i block_id = + _mm256_srlv_epi32(hash, _mm256_set1_epi32(bits_hash_ - log_blocks_)); + __m256i local_slot = _mm256_set1_epi64x(reinterpret_cast(local_slots)[i]); + + // Extend block_id and local_slot to 64-bit to compute 64-bit group id offsets to + // gather from. This is to prevent index overflow issues in GH-44513. __m256i local_slot_lo = _mm256_shuffle_epi8( local_slot, _mm256_setr_epi32(0x80808000, 0x80808080, 0x80808001, 0x80808080, 0x80808002, 0x80808080, 0x80808003, 0x80808080)); __m256i local_slot_hi = _mm256_shuffle_epi8( local_slot, _mm256_setr_epi32(0x80808004, 0x80808080, 0x80808005, 0x80808080, 0x80808006, 0x80808080, 0x80808007, 0x80808080)); - local_slot_lo = _mm256_mul_epu32(local_slot_lo, _mm256_set1_epi32(byte_size)); - local_slot_hi = _mm256_mul_epu32(local_slot_hi, _mm256_set1_epi32(byte_size)); - __m256i pos_lo = _mm256_srli_epi64(hash_lo, bits_hash_ - log_blocks_); - __m256i pos_hi = _mm256_srli_epi64(hash_hi, bits_hash_ - log_blocks_); - pos_lo = _mm256_mul_epu32(pos_lo, _mm256_set1_epi32(byte_multiplier)); - pos_hi = _mm256_mul_epu32(pos_hi, _mm256_set1_epi32(byte_multiplier)); - pos_lo = _mm256_add_epi64(pos_lo, local_slot_lo); - pos_hi = _mm256_add_epi64(pos_hi, local_slot_hi); - __m128i group_id_lo = _mm256_i64gather_epi32(elements, pos_lo, 1); - __m128i group_id_hi = _mm256_i64gather_epi32(elements, pos_hi, 1); + local_slot_lo = + _mm256_mul_epu32(local_slot_lo, _mm256_set1_epi32(num_groupid_bytes)); + local_slot_hi = + _mm256_mul_epu32(local_slot_hi, _mm256_set1_epi32(num_groupid_bytes)); + + // NB: Use zero-extend conversion for unsigned block_id. 
+ __m256i slot_offset_lo = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(block_id)); + __m256i slot_offset_hi = + _mm256_cvtepu32_epi64(_mm256_extracti128_si256(block_id, 1)); + slot_offset_lo = + _mm256_mul_epi32(slot_offset_lo, _mm256_set1_epi64x(num_block_bytes)); + slot_offset_hi = + _mm256_mul_epi32(slot_offset_hi, _mm256_set1_epi64x(num_block_bytes)); + slot_offset_lo = _mm256_add_epi64(slot_offset_lo, local_slot_lo); + slot_offset_hi = _mm256_add_epi64(slot_offset_hi, local_slot_hi); + + __m128i group_id_lo = _mm256_i64gather_epi32(slots_base, slot_offset_lo, 1); + __m128i group_id_hi = _mm256_i64gather_epi32(slots_base, slot_offset_hi, 1); __m256i group_id = _mm256_set_m128i(group_id_hi, group_id_lo); group_id = _mm256_and_si256(group_id, _mm256_set1_epi32(mask)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id); } } From a1f9758fc24726c076909b9c9c30e213b3a28fcc Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Wed, 12 Feb 2025 20:59:58 +0800 Subject: [PATCH 2/6] Remove already implied casting --- cpp/src/arrow/compute/key_map_internal.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/key_map_internal.h b/cpp/src/arrow/compute/key_map_internal.h index cd3ceee325fc5..459367a67e8de 100644 --- a/cpp/src/arrow/compute/key_map_internal.h +++ b/cpp/src/arrow/compute/key_map_internal.h @@ -102,7 +102,7 @@ class ARROW_EXPORT SwissTable { } static uint32_t global_slot_id(uint32_t block_id, uint32_t local_slot_id) { - return block_id * static_cast(kSlotsPerBlock) + local_slot_id; + return block_id * kSlotsPerBlock + local_slot_id; } static int num_groupid_bits_from_log_blocks(int log_blocks) { From f5db1597ac07bd78964ff6cbd02fd30b1f9141f1 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Wed, 12 Feb 2025 21:01:37 +0800 Subject: [PATCH 3/6] Revert some unintended reordering --- cpp/src/arrow/compute/key_map_internal.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/key_map_internal.h b/cpp/src/arrow/compute/key_map_internal.h index 459367a67e8de..8866c3566398a 100644 --- a/cpp/src/arrow/compute/key_map_internal.h +++ b/cpp/src/arrow/compute/key_map_internal.h @@ -85,8 +85,6 @@ class ARROW_EXPORT SwissTable { return reinterpret_cast(hashes_->mutable_data()); } - inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id); - /// \brief Extract group id for a given slot in a given block. 
/// static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot, @@ -97,6 +95,8 @@ class ARROW_EXPORT SwissTable { return group_id & group_id_mask; } + inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id); + static uint32_t block_id_from_hash(uint32_t hash, int log_blocks) { return hash >> (bits_hash_ - log_blocks); } From 7af1d3c882f87a8c800e4167e7ae9284e9c2ac8a Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Thu, 13 Feb 2025 19:58:23 +0800 Subject: [PATCH 4/6] Use aligned read to extract group id --- cpp/src/arrow/compute/key_map_internal.cc | 54 +++++++++++++++++------ cpp/src/arrow/compute/key_map_internal.h | 21 ++++----- 2 files changed, 51 insertions(+), 24 deletions(-) diff --git a/cpp/src/arrow/compute/key_map_internal.cc b/cpp/src/arrow/compute/key_map_internal.cc index 860cebfd33ba6..8440ada97c748 100644 --- a/cpp/src/arrow/compute/key_map_internal.cc +++ b/cpp/src/arrow/compute/key_map_internal.cc @@ -94,11 +94,12 @@ inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, *out_slot = static_cast(CountLeadingZeros(matches | block_high_bits) >> 3); } -template +template void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selection, const uint32_t* hashes, const uint8_t* local_slots, uint32_t* out_group_ids) const { if (log_blocks_ == 0) { + DCHECK_EQ(sizeof(T), sizeof(uint8_t)); for (int i = 0; i < num_keys; ++i) { uint32_t id = use_selection ? selection[i] : i; uint32_t group_id = @@ -108,18 +109,16 @@ void SwissTable::extract_group_ids_imp(const int num_keys, const uint16_t* selec } } else { int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); - int num_groupid_bytes = num_groupid_bits / 8; - uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_groupid_bits); + DCHECK_EQ(sizeof(T) * 8, num_groupid_bits); int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); for (int i = 0; i < num_keys; ++i) { uint32_t id = use_selection ? 
selection[i] : i; uint32_t hash = hashes[id]; uint32_t block_id = block_id_from_hash(hash, log_blocks_); - uint32_t group_id = *reinterpret_cast( - block_data(block_id, num_block_bytes) + local_slots[id] * num_groupid_bytes + - bytes_status_in_block_); - group_id &= group_id_mask; + const T* slots_base = reinterpret_cast( + block_data(block_id, num_block_bytes) + bytes_status_in_block_); + uint32_t group_id = static_cast(slots_base[local_slots[id]]); out_group_ids[id] = group_id; } } @@ -137,13 +136,40 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_ num_processed = extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids); } #endif - if (optional_selection) { - extract_group_ids_imp(num_keys, optional_selection, hashes, local_slots, - out_group_ids); - } else { - extract_group_ids_imp(num_keys - num_processed, nullptr, - hashes + num_processed, local_slots + num_processed, - out_group_ids + num_processed); + int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); + switch (num_groupid_bits) { + case 8: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids); + } else { + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed); + } + break; + case 16: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids); + } else { + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed); + } + break; + case 32: + if (optional_selection) { + extract_group_ids_imp(num_keys, optional_selection, hashes, + local_slots, out_group_ids); + } else { + extract_group_ids_imp( + num_keys - num_processed, nullptr, hashes + num_processed, + local_slots + num_processed, out_group_ids + num_processed); + } + break; + default: + DCHECK(false); } } diff --git a/cpp/src/arrow/compute/key_map_internal.h b/cpp/src/arrow/compute/key_map_internal.h index 8866c3566398a..b3849b2660e10 100644 --- a/cpp/src/arrow/compute/key_map_internal.h +++ b/cpp/src/arrow/compute/key_map_internal.h @@ -88,11 +88,15 @@ class ARROW_EXPORT SwissTable { /// \brief Extract group id for a given slot in a given block. /// static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot, - int64_t num_group_id_bits) { - uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_group_id_bits); - uint32_t group_id = *reinterpret_cast( - block_ptr + bytes_status_in_block_ + local_slot * num_group_id_bits / 8); - return group_id & group_id_mask; + int num_group_id_bits) { + // Extract group id using aligned 32-bit read. + uint32_t group_id_mask = static_cast((1ULL << num_group_id_bits) - 1); + int slot_bit_offset = local_slot * num_group_id_bits; + const uint32_t* group_id_ptr32 = + reinterpret_cast(block_ptr + bytes_status_in_block_) + + (slot_bit_offset >> 5); + uint32_t group_id = (*group_id_ptr32 >> (slot_bit_offset & 31)) & group_id_mask; + return group_id; } inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id); @@ -106,6 +110,7 @@ class ARROW_EXPORT SwissTable { } static int num_groupid_bits_from_log_blocks(int log_blocks) { + assert(log_blocks >= 0 && log_blocks <= 32 - 3); int required_bits = log_blocks + 3; return required_bits <= 8 ? 8 : required_bits <= 16 ? 
16 @@ -178,7 +183,7 @@ class ARROW_EXPORT SwissTable { const uint32_t* hashes, const uint8_t* local_slots, uint32_t* out_group_ids) const; - template + template void extract_group_ids_imp(const int num_keys, const uint16_t* selection, const uint32_t* hashes, const uint8_t* local_slots, uint32_t* out_group_ids) const; @@ -257,10 +262,6 @@ class ARROW_EXPORT SwissTable { return bits_stamp_; } - static uint32_t group_id_mask_from_num_groupid_bits(int64_t num_groupid_bits) { - return static_cast((1ULL << num_groupid_bits) - 1); - } - static constexpr int bytes_status_in_block_ = 8; // Number of hash bits stored in slots in a block. From aeb8b9d0fbe1729b1f76fd26dd65397c058eab68 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Thu, 13 Feb 2025 22:28:10 +0800 Subject: [PATCH 5/6] Some more cleanup found in the last commit --- cpp/src/arrow/compute/key_map_internal.cc | 66 ++++++++++++----------- cpp/src/arrow/compute/key_map_internal.h | 56 +++++++++++-------- 2 files changed, 68 insertions(+), 54 deletions(-) diff --git a/cpp/src/arrow/compute/key_map_internal.cc b/cpp/src/arrow/compute/key_map_internal.cc index 8440ada97c748..910e2944ed84c 100644 --- a/cpp/src/arrow/compute/key_map_internal.cc +++ b/cpp/src/arrow/compute/key_map_internal.cc @@ -272,14 +272,14 @@ void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes, // How many groups we can keep in the hash table without the need for resizing. // When we reach this limit, we need to break processing of any further rows and resize. // -uint64_t SwissTable::num_groups_for_resize() const { +int64_t SwissTable::num_groups_for_resize() const { // Consider N = 9 (aka 2 ^ 9 = 512 blocks) as small. // When N = 9, a slot id takes N + 3 = 12 bits, rounded up to 16 bits. This is also the // number of bits needed for a key id. Since each slot stores a status byte and a key // id, then a slot takes 1 byte + 16 bits = 3 bytes. Therefore a block of 8 slots takes // 24 bytes. The threshold of a small hash table ends up being 24 bytes * 512 = 12 KB. constexpr int log_blocks_small_ = 9; - uint64_t num_slots = 1ULL << (log_blocks_ + 3); + int64_t num_slots = num_slots_from_log_blocks(log_blocks_); if (log_blocks_ <= log_blocks_small_) { // Resize small hash tables when 50% full. return num_slots / 2; @@ -290,7 +290,8 @@ uint64_t SwissTable::num_groups_for_resize() const { } uint32_t SwissTable::wrap_global_slot_id(uint32_t global_slot_id) const { - uint32_t global_slot_id_mask = static_cast((1ULL << (log_blocks_ + 3)) - 1); + uint32_t global_slot_id_mask = + static_cast((1ULL << (log_blocks_ + kLogSlotsPerBlock)) - 1ULL); return global_slot_id & global_slot_id_mask; } @@ -398,18 +399,20 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl int local_slot; const uint8_t* blockbase; for (;;) { - blockbase = block_data(start_slot_id >> 3, num_block_bytes); + blockbase = block_data(start_slot_id >> kLogSlotsPerBlock, num_block_bytes); uint64_t block = *reinterpret_cast(blockbase); - search_block(block, stamp, start_slot_id & 7, &local_slot, &match_found); + search_block(block, stamp, start_slot_id & kLocalSlotMask, &local_slot, + &match_found); - start_slot_id = wrap_global_slot_id((start_slot_id & ~7U) + local_slot + match_found); + start_slot_id = + wrap_global_slot_id((start_slot_id & ~kLocalSlotMask) + local_slot + match_found); // Match found can be 1 in two cases: // - match was found // - match was not found in a full block // In the second case search needs to continue in the next block. 
- if (match_found == 0 || blockbase[7 - local_slot] == stamp) { + if (match_found == 0 || blockbase[kMaxLocalSlot - local_slot] == stamp) { break; } } @@ -635,7 +638,7 @@ Status SwissTable::map_new_keys(uint32_t num_ids, uint16_t* ids, const uint32_t* // First slot in the new starting block const int16_t id = ids[i]; uint32_t block_id = block_id_from_hash(hashes[id], log_blocks_); - slot_ids[id] = global_slot_id(block_id, 0); + slot_ids[id] = global_slot_id(block_id, /*local_slot_id=*/0); } } } while (num_ids > 0); @@ -647,7 +650,8 @@ Status SwissTable::grow_double() { // Before and after metadata int num_group_id_bits_before = num_groupid_bits_from_log_blocks(log_blocks_); int num_group_id_bits_after = num_groupid_bits_from_log_blocks(log_blocks_ + 1); - uint64_t group_id_mask_before = ~0ULL >> (64 - num_group_id_bits_before); + uint32_t group_id_mask_before = + group_id_mask_from_num_groupid_bits(num_group_id_bits_before); int log_blocks_after = log_blocks_ + 1; int bits_shift_for_block_and_stamp_after = ComputeBitsShiftForBlockAndStamp(log_blocks_after); @@ -657,7 +661,7 @@ Status SwissTable::grow_double() { int64_t block_size_total_after = num_bytes_total_blocks(block_size_after, log_blocks_after); int64_t hashes_size_total_after = - (bits_hash_ / 8 * (1 << (log_blocks_after + 3))) + padding_; + (bits_hash_ / 8 * num_slots_from_log_blocks(log_blocks_after)) + padding_; constexpr uint32_t stamp_mask = (1 << bits_stamp_) - 1; // Allocate new buffers @@ -685,7 +689,7 @@ Status SwissTable::grow_double() { util::SafeStore(double_block_base_new + block_size_after, kHighBitOfEachByte); for (uint32_t j = 0; j < full_slots; ++j) { - uint64_t slot_id = global_slot_id(i, j); + uint32_t slot_id = global_slot_id(i, j); uint32_t hash = hashes()[slot_id]; uint32_t block_id_new = block_id_from_hash(hash, log_blocks_after); bool is_overflow_entry = ((block_id_new >> 1) != static_cast(i)); @@ -695,22 +699,22 @@ Status SwissTable::grow_double() { uint32_t ihalf = block_id_new & 1; uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask; - uint64_t group_id_bit_offs = j * num_group_id_bits_before; - uint64_t group_id = - (util::SafeLoadAs(block_base + bytes_status_in_block_ + + int group_id_bit_offs = j * num_group_id_bits_before; + uint32_t group_id = + (util::SafeLoadAs(block_base + bytes_status_in_block_ + (group_id_bit_offs >> 3)) >> (group_id_bit_offs & 7)) & group_id_mask_before; - uint64_t slot_id_new = global_slot_id(i * 2 + ihalf, full_slots_new[ihalf]); + uint32_t slot_id_new = global_slot_id(i * 2 + ihalf, full_slots_new[ihalf]); hashes_new[slot_id_new] = hash; uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after; - block_base_new[7 - full_slots_new[ihalf]] = stamp_new; - int64_t group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after; + block_base_new[kMaxLocalSlot - full_slots_new[ihalf]] = stamp_new; + int group_id_bit_offs_new = full_slots_new[ihalf] * num_group_id_bits_after; uint64_t* ptr = reinterpret_cast( block_base_new + bytes_status_in_block_ + (group_id_bit_offs_new >> 3)); - util::SafeStore(ptr, - util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7))); + util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast(group_id) + << (group_id_bit_offs_new & 7))); full_slots_new[ihalf]++; } } @@ -724,7 +728,7 @@ Status SwissTable::grow_double() { uint32_t full_slots = CountLeadingZeros(block & kHighBitOfEachByte) >> 3; for (uint32_t j = 0; j < full_slots; ++j) { - uint64_t slot_id = global_slot_id(i, j); + uint32_t 
slot_id = global_slot_id(i, j); uint32_t hash = hashes()[slot_id]; uint32_t block_id_new = block_id_from_hash(hash, log_blocks_after); bool is_overflow_entry = ((block_id_new >> 1) != static_cast(i)); @@ -732,9 +736,9 @@ Status SwissTable::grow_double() { continue; } - uint64_t group_id_bit_offs = j * num_group_id_bits_before; - uint64_t group_id = - (util::SafeLoadAs(block_base + bytes_status_in_block_ + + int group_id_bit_offs = j * num_group_id_bits_before; + uint32_t group_id = + (util::SafeLoadAs(block_base + bytes_status_in_block_ + (group_id_bit_offs >> 3)) >> (group_id_bit_offs & 7)) & group_id_mask_before; @@ -753,13 +757,13 @@ Status SwissTable::grow_double() { static_cast(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3); } - hashes_new[block_id_new * 8 + full_slots_new] = hash; - block_base_new[7 - full_slots_new] = stamp_new; - int64_t group_id_bit_offs_new = full_slots_new * num_group_id_bits_after; + hashes_new[block_id_new * kSlotsPerBlock + full_slots_new] = hash; + block_base_new[kMaxLocalSlot - full_slots_new] = stamp_new; + int group_id_bit_offs_new = full_slots_new * num_group_id_bits_after; uint64_t* ptr = reinterpret_cast( block_base_new + bytes_status_in_block_ + (group_id_bit_offs_new >> 3)); - util::SafeStore(ptr, - util::SafeLoad(ptr) | (group_id << (group_id_bit_offs_new & 7))); + util::SafeStore(ptr, util::SafeLoad(ptr) | (static_cast(group_id) + << (group_id_bit_offs_new & 7))); } } @@ -800,9 +804,9 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks if (no_hash_array) { hashes_ = nullptr; } else { - uint64_t num_slots = 1ULL << (log_blocks_ + 3); - const uint64_t hash_size = sizeof(uint32_t); - const uint64_t hash_bytes = hash_size * num_slots + padding_; + int64_t num_slots = num_slots_from_log_blocks(log_blocks); + const int hash_size = bits_hash_ >> 3; + const int64_t hash_bytes = hash_size * num_slots + padding_; ARROW_ASSIGN_OR_RAISE(hashes_, AllocateBuffer(hash_bytes, pool_)); } diff --git a/cpp/src/arrow/compute/key_map_internal.h b/cpp/src/arrow/compute/key_map_internal.h index b3849b2660e10..a6d136b95f2f5 100644 --- a/cpp/src/arrow/compute/key_map_internal.h +++ b/cpp/src/arrow/compute/key_map_internal.h @@ -90,7 +90,7 @@ class ARROW_EXPORT SwissTable { static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot, int num_group_id_bits) { // Extract group id using aligned 32-bit read. - uint32_t group_id_mask = static_cast((1ULL << num_group_id_bits) - 1); + uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_group_id_bits); int slot_bit_offset = local_slot * num_group_id_bits; const uint32_t* group_id_ptr32 = reinterpret_cast(block_ptr + bytes_status_in_block_) + @@ -99,7 +99,8 @@ class ARROW_EXPORT SwissTable { return group_id; } - inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id); + inline void insert_into_empty_slot(uint32_t global_slot_id, uint32_t hash, + uint32_t group_id); static uint32_t block_id_from_hash(uint32_t hash, int log_blocks) { return hash >> (bits_hash_ - log_blocks); @@ -110,22 +111,16 @@ class ARROW_EXPORT SwissTable { } static int num_groupid_bits_from_log_blocks(int log_blocks) { - assert(log_blocks >= 0 && log_blocks <= 32 - 3); - int required_bits = log_blocks + 3; - return required_bits <= 8 ? 8 - : required_bits <= 16 ? 16 - : required_bits <= 32 ? 32 - : 64; + assert(log_blocks >= 0); + int required_bits = log_blocks + kLogSlotsPerBlock; + assert(required_bits <= 32); + return required_bits <= 8 ? 
8 : required_bits <= 16 ? 16 : 32; } static int num_block_bytes_from_num_groupid_bits(int num_groupid_bits) { return num_groupid_bits + bytes_status_in_block_; } - static int64_t num_bytes_total_blocks(int num_block_bytes, int log_blocks) { - return (static_cast(num_block_bytes) << log_blocks) + padding_; - } - const uint8_t* block_data(uint32_t block_id, int num_block_bytes) const { return block_data(blocks_->data(), block_id, num_block_bytes); } @@ -188,10 +183,24 @@ class ARROW_EXPORT SwissTable { const uint32_t* hashes, const uint8_t* local_slots, uint32_t* out_group_ids) const; - inline uint64_t next_slot_to_visit(uint64_t block_index, int slot, - int match_found) const; + static constexpr int kLogSlotsPerBlock = 3; + static constexpr int kMaxLocalSlot = kSlotsPerBlock - 1; + static constexpr uint32_t kLocalSlotMask = (1U << kLogSlotsPerBlock) - 1U; + + static int64_t num_slots_from_log_blocks(int log_blocks) { + return 1LL << (log_blocks + kLogSlotsPerBlock); + } + + static int64_t num_bytes_total_blocks(int num_block_bytes, int log_blocks) { + return (static_cast(num_block_bytes) << log_blocks) + padding_; + } + + inline int64_t num_groups_for_resize() const; - inline uint64_t num_groups_for_resize() const; + static uint32_t group_id_mask_from_num_groupid_bits(int num_groupid_bits) { + // num_groupid_bits could be 32, so using 64-bit shifting. + return static_cast((1ULL << num_groupid_bits) - 1ULL); + } inline uint32_t wrap_global_slot_id(uint32_t global_slot_id) const; @@ -307,7 +316,7 @@ class ARROW_EXPORT SwissTable { MemoryPool* pool_; }; -void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, +void SwissTable::insert_into_empty_slot(uint32_t global_slot_id, uint32_t hash, uint32_t group_id) { const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_); @@ -317,19 +326,20 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash, num_groupid_bits == 64); const int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits); - constexpr uint64_t stamp_mask = 0x7f; + constexpr uint32_t stamp_mask = 0x7f; - int start_slot = (slot_id & 7); - int stamp = static_cast((hash >> bits_shift_for_block_and_stamp_) & stamp_mask); - uint32_t block_id = slot_id >> 3; + int start_slot = (global_slot_id & kLocalSlotMask); + int stamp = (hash >> bits_shift_for_block_and_stamp_) & stamp_mask; + uint32_t block_id = global_slot_id >> kLogSlotsPerBlock; uint8_t* blockbase = mutable_block_data(block_id, num_block_bytes); - blockbase[7 - start_slot] = static_cast(stamp); - int groupid_bit_offset = static_cast(start_slot * num_groupid_bits); + blockbase[kMaxLocalSlot - start_slot] = static_cast(stamp); + int groupid_bit_offset = start_slot * num_groupid_bits; // Block status bytes should start at an address aligned to 8 bytes assert((reinterpret_cast(blockbase) & 7) == 0); - uint64_t* ptr = reinterpret_cast(blockbase) + 1 + (groupid_bit_offset >> 6); + uint64_t* ptr = reinterpret_cast(blockbase + bytes_status_in_block_) + + (groupid_bit_offset >> 6); *ptr |= (static_cast(group_id) << (groupid_bit_offset & 63)); } From 95c6990a212de7fede85703345a0e23df888f491 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Fri, 14 Feb 2025 03:00:49 +0800 Subject: [PATCH 6/6] Unify extracting group id code by using aligned 32-bit read --- cpp/src/arrow/acero/swiss_join.cc | 6 ++++-- cpp/src/arrow/compute/key_map_internal.cc | 17 +++++------------ cpp/src/arrow/compute/key_map_internal.h | 20 +++++++++++--------- 3 files changed, 20 insertions(+), 23 
diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc
index b1ba77216e1a5..d99c2f08c0134 100644
--- a/cpp/src/arrow/acero/swiss_join.cc
+++ b/cpp/src/arrow/acero/swiss_join.cc
@@ -645,6 +645,8 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
       SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
   int source_block_bytes =
       SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits);
+  uint32_t source_group_id_mask =
+      SwissTable::group_id_mask_from_num_groupid_bits(source_group_id_bits);
   ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0);
 
   // Compute index of the last block in target that corresponds to the given
@@ -670,8 +672,8 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
     for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) {
       // Read group id and hash for this slot.
       //
-      uint32_t group_id =
-          source->extract_group_id(block_bytes, local_slot_id, source_group_id_bits);
+      uint32_t group_id = SwissTable::extract_group_id(
+          block_bytes, local_slot_id, source_group_id_bits, source_group_id_mask);
       uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
       uint32_t hash = source->hashes()[global_slot_id];
       // Insert partition id into the highest bits of hash, shifting the
diff --git a/cpp/src/arrow/compute/key_map_internal.cc b/cpp/src/arrow/compute/key_map_internal.cc
index 910e2944ed84c..e44177d6a6f91 100644
--- a/cpp/src/arrow/compute/key_map_internal.cc
+++ b/cpp/src/arrow/compute/key_map_internal.cc
@@ -391,6 +391,7 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
                                        uint32_t* out_group_id) const {
   const int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
   const int num_block_bytes = num_block_bytes_from_num_groupid_bits(num_groupid_bits);
+  const uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_groupid_bits);
   constexpr uint64_t stamp_mask = 0x7f;
   const int stamp =
       static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
@@ -417,7 +418,8 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
     }
   }
 
-  *out_group_id = extract_group_id(blockbase, local_slot, num_groupid_bits);
+  *out_group_id =
+      extract_group_id(blockbase, local_slot, num_groupid_bits, group_id_mask);
   *out_slot_id = start_slot_id;
 
   return match_found;
@@ -699,13 +701,8 @@ Status SwissTable::grow_double() {
     uint32_t ihalf = block_id_new & 1;
     uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
 
-    int group_id_bit_offs = j * num_group_id_bits_before;
     uint32_t group_id =
-        (util::SafeLoadAs<uint32_t>(block_base + bytes_status_in_block_ +
-                                    (group_id_bit_offs >> 3)) >>
-         (group_id_bit_offs & 7)) &
-        group_id_mask_before;
-
+        extract_group_id(block_base, j, num_group_id_bits_before, group_id_mask_before);
     uint32_t slot_id_new = global_slot_id(i * 2 + ihalf, full_slots_new[ihalf]);
     hashes_new[slot_id_new] = hash;
     uint8_t* block_base_new = double_block_base_new + ihalf * block_size_after;
@@ -736,12 +733,8 @@ Status SwissTable::grow_double() {
       continue;
     }
 
-    int group_id_bit_offs = j * num_group_id_bits_before;
     uint32_t group_id =
-        (util::SafeLoadAs<uint32_t>(block_base + bytes_status_in_block_ +
-                                    (group_id_bit_offs >> 3)) >>
-         (group_id_bit_offs & 7)) &
-        group_id_mask_before;
+        extract_group_id(block_base, j, num_group_id_bits_before, group_id_mask_before);
 
     uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
     uint8_t* block_base_new =
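With the mask passed in explicitly, both call sites above hoist group_id_mask_from_num_groupid_bits out of their per-slot work: find_next_stamp_match derives it once per lookup and grow_double once per resize, while a debug assert inside extract_group_id keeps the two arguments from drifting apart. A sketch of that pattern, with hypothetical names rather than the Arrow implementation:

    #include <cassert>
    #include <cstdint>

    uint32_t mask_from_bits(int num_groupid_bits) {
      return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1ULL);
    }

    uint32_t extract(const uint32_t* packed, int local_slot, int num_groupid_bits,
                     uint32_t group_id_mask) {
      // Costs nothing in release builds; catches inconsistent callers in debug.
      assert(mask_from_bits(num_groupid_bits) == group_id_mask);
      int bit_offset = local_slot * num_groupid_bits;
      return (packed[bit_offset >> 5] >> (bit_offset & 31)) & group_id_mask;
    }

    void extract_all(const uint32_t* packed, int num_slots, int num_groupid_bits,
                     uint32_t* out) {
      // Hoisted: the mask is derived once per call, not once per slot.
      const uint32_t mask = mask_from_bits(num_groupid_bits);
      for (int slot = 0; slot < num_slots; ++slot) {
        out[slot] = extract(packed, slot, num_groupid_bits, mask);
      }
    }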
diff --git a/cpp/src/arrow/compute/key_map_internal.h b/cpp/src/arrow/compute/key_map_internal.h
index a6d136b95f2f5..8423134cb3269 100644
--- a/cpp/src/arrow/compute/key_map_internal.h
+++ b/cpp/src/arrow/compute/key_map_internal.h
@@ -85,12 +85,14 @@ class ARROW_EXPORT SwissTable {
     return reinterpret_cast<uint32_t*>(hashes_->mutable_data());
   }
 
-  /// \brief Extract group id for a given slot in a given block.
+  /// \brief Extract group id for a given slot in a given block, using an aligned
+  /// 32-bit read regardless of the number of group id bits.
+  /// Note that group_id_mask must be derived from num_group_id_bits. The function
+  /// takes both and, for performance, verifies their consistency only in debug builds.
   ///
   static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot,
-                                   int num_group_id_bits) {
-    // Extract group id using aligned 32-bit read.
-    uint32_t group_id_mask = group_id_mask_from_num_groupid_bits(num_group_id_bits);
+                                   int num_group_id_bits, uint32_t group_id_mask) {
+    assert(group_id_mask_from_num_groupid_bits(num_group_id_bits) == group_id_mask);
     int slot_bit_offset = local_slot * num_group_id_bits;
     const uint32_t* group_id_ptr32 =
         reinterpret_cast<const uint32_t*>(block_ptr + bytes_status_in_block_) +
@@ -121,6 +123,11 @@ class ARROW_EXPORT SwissTable {
     return num_groupid_bits + bytes_status_in_block_;
   }
 
+  static uint32_t group_id_mask_from_num_groupid_bits(int num_groupid_bits) {
+    // num_groupid_bits could be 32, so using 64-bit shifting.
+    return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1ULL);
+  }
+
   const uint8_t* block_data(uint32_t block_id, int num_block_bytes) const {
     return block_data(blocks_->data(), block_id, num_block_bytes);
   }
@@ -197,11 +204,6 @@ class ARROW_EXPORT SwissTable {
 
   inline int64_t num_groups_for_resize() const;
 
-  static uint32_t group_id_mask_from_num_groupid_bits(int num_groupid_bits) {
-    // num_groupid_bits could be 32, so using 64-bit shifting.
-    return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1ULL);
-  }
-
   inline uint32_t wrap_global_slot_id(uint32_t global_slot_id) const;
 
   void init_slot_ids(const int num_keys, const uint16_t* selection,
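The comment inside group_id_mask_from_num_groupid_bits is load-bearing: when num_groupid_bits is 32, a 32-bit shift such as 1u << 32 is undefined behavior in C++ because the shift count equals the operand width. Widening to 64 bits keeps the shift well defined for all three widths; a minimal standalone demonstration:

    #include <cassert>
    #include <cstdint>

    uint32_t mask_from_bits(int num_groupid_bits) {
      // 1u << 32 would be UB; 1ULL << 32 is a well-defined 64-bit shift.
      return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1ULL);
    }

    int main() {
      assert(mask_from_bits(8) == 0xffu);
      assert(mask_from_bits(16) == 0xffffu);
      assert(mask_from_bits(32) == 0xffffffffu);
      return 0;
    }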