-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GH-45506: [C++][Acero] More overflow-safe Swiss table #45515
Changes from all commits
1dd5c08
a1f9758
f5db159
7af1d3c
aeb8b9d
95c6990
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -643,37 +643,38 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc | |
// | ||
int source_group_id_bits = | ||
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks()); | ||
uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits); | ||
int64_t source_block_bytes = source_group_id_bits + 8; | ||
int source_block_bytes = | ||
SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits); | ||
uint32_t source_group_id_mask = | ||
SwissTable::group_id_mask_from_num_groupid_bits(source_group_id_bits); | ||
ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0); | ||
|
||
// Compute index of the last block in target that corresponds to the given | ||
// partition. | ||
// | ||
ARROW_DCHECK(num_partition_bits <= target->log_blocks()); | ||
int64_t target_max_block_id = | ||
uint32_t target_max_block_id = | ||
((partition_id + 1) << (target->log_blocks() - num_partition_bits)) - 1; | ||
|
||
overflow_group_ids->clear(); | ||
overflow_hashes->clear(); | ||
|
||
// For each source block... | ||
int64_t source_blocks = 1LL << source->log_blocks(); | ||
for (int64_t block_id = 0; block_id < source_blocks; ++block_id) { | ||
uint8_t* block_bytes = source->blocks() + block_id * source_block_bytes; | ||
uint32_t source_blocks = 1 << source->log_blocks(); | ||
for (uint32_t block_id = 0; block_id < source_blocks; ++block_id) { | ||
const uint8_t* block_bytes = source->block_data(block_id, source_block_bytes); | ||
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes); | ||
|
||
// For each non-empty source slot... | ||
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; | ||
constexpr int kSlotsPerBlock = 8; | ||
int num_full_slots = | ||
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); | ||
int num_full_slots = SwissTable::kSlotsPerBlock - | ||
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); | ||
for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) { | ||
// Read group id and hash for this slot. | ||
// | ||
uint64_t group_id = | ||
source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask); | ||
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id; | ||
uint32_t group_id = SwissTable::extract_group_id( | ||
block_bytes, local_slot_id, source_group_id_bits, source_group_id_mask); | ||
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id); | ||
uint32_t hash = source->hashes()[global_slot_id]; | ||
// Insert partition id into the highest bits of hash, shifting the | ||
// remaining hash bits right. | ||
|
@@ -696,17 +697,18 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc | |
} | ||
} | ||
|
||
inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id, | ||
uint32_t hash, int64_t max_block_id) { | ||
inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint32_t group_id, | ||
uint32_t hash, uint32_t max_block_id) { | ||
// Load the first block to visit for this hash | ||
// | ||
int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks()); | ||
int64_t block_id_mask = ((1LL << target->log_blocks()) - 1); | ||
uint32_t block_id = SwissTable::block_id_from_hash(hash, target->log_blocks()); | ||
uint32_t block_id_mask = (1 << target->log_blocks()) - 1; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would a signed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
int num_group_id_bits = | ||
SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks()); | ||
int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t); | ||
int num_block_bytes = | ||
SwissTable::num_block_bytes_from_num_groupid_bits(num_group_id_bits); | ||
ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0); | ||
uint8_t* block_bytes = target->blocks() + block_id * num_block_bytes; | ||
const uint8_t* block_bytes = target->block_data(block_id, num_block_bytes); | ||
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes); | ||
|
||
// Search for the first block with empty slots. | ||
|
@@ -715,25 +717,23 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i | |
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; | ||
while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id) { | ||
block_id = (block_id + 1) & block_id_mask; | ||
block_bytes = target->blocks() + block_id * num_block_bytes; | ||
block_bytes = target->block_data(block_id, num_block_bytes); | ||
block = *reinterpret_cast<const uint64_t*>(block_bytes); | ||
} | ||
if ((block & kHighBitOfEachByte) == 0) { | ||
return false; | ||
} | ||
constexpr int kSlotsPerBlock = 8; | ||
int local_slot_id = | ||
kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); | ||
int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id; | ||
target->insert_into_empty_slot(static_cast<uint32_t>(global_slot_id), hash, | ||
static_cast<uint32_t>(group_id)); | ||
int local_slot_id = SwissTable::kSlotsPerBlock - | ||
static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte)); | ||
uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id); | ||
target->insert_into_empty_slot(global_slot_id, hash, group_id); | ||
return true; | ||
} | ||
|
||
void SwissTableMerge::InsertNewGroups(SwissTable* target, | ||
const std::vector<uint32_t>& group_ids, | ||
const std::vector<uint32_t>& hashes) { | ||
int64_t num_blocks = 1LL << target->log_blocks(); | ||
uint32_t num_blocks = 1 << target->log_blocks(); | ||
for (size_t i = 0; i < group_ids.size(); ++i) { | ||
std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks); | ||
} | ||
|
@@ -1191,7 +1191,7 @@ Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id, | |
// We want each partition to correspond to a range of block indices, | ||
// so we also partition on the highest bits of the hash. | ||
// | ||
return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1; | ||
return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_); | ||
}, | ||
[&locals](int64_t i, int pos) { | ||
locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After cleaning up these unnecessary 64-bit in this file, we can further cleanup some temp states as mentioned in #45336 (comment) .