GH-45506: [C++][Acero] More overflow-safe Swiss table #45515

Merged · 6 commits · Feb 17, 2025

Changes from all commits
54 changes: 27 additions & 27 deletions cpp/src/arrow/acero/swiss_join.cc
@@ -643,37 +643,38 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
//
int source_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
- uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
Contributor Author:

After cleaning up these unnecessary 64-bit types in this file, we can further clean up some temporary state, as mentioned in #45336 (comment).
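
For orientation, here is a rough sketch of what the two new static helpers used just below in this hunk presumably compute, inferred from the inline expressions they replace. The names and exact types here are stand-ins, not the real SwissTable API; treat them as assumptions.

#include <cstdint>

constexpr int kSlotsPerBlock = 8;

// A block stores 8 status bytes plus 8 group ids of num_groupid_bits bits
// each, i.e. 8 + num_groupid_bits bytes in total (matching the old
// `source_group_id_bits + 8` expression removed in this hunk).
constexpr int NumBlockBytesFromNumGroupIdBits(int num_groupid_bits) {
  return 8 + num_groupid_bits * kSlotsPerBlock / 8;
}

// Mask with the low num_groupid_bits bits set, now computed in 32 bits
// instead of the old 64-bit `~0ULL >> (64 - source_group_id_bits)` form.
constexpr uint32_t GroupIdMaskFromNumGroupIdBits(int num_groupid_bits) {
  return static_cast<uint32_t>((uint64_t{1} << num_groupid_bits) - 1);
}

static_assert(NumBlockBytesFromNumGroupIdBits(8) == 16, "");
static_assert(GroupIdMaskFromNumGroupIdBits(8) == 0xFFu, "");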

- int64_t source_block_bytes = source_group_id_bits + 8;
+ int source_block_bytes =
+ SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits);
+ uint32_t source_group_id_mask =
+ SwissTable::group_id_mask_from_num_groupid_bits(source_group_id_bits);
ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0);

// Compute index of the last block in target that corresponds to the given
// partition.
//
ARROW_DCHECK(num_partition_bits <= target->log_blocks());
- int64_t target_max_block_id =
+ uint32_t target_max_block_id =
((partition_id + 1) << (target->log_blocks() - num_partition_bits)) - 1;

overflow_group_ids->clear();
overflow_hashes->clear();

// For each source block...
- int64_t source_blocks = 1LL << source->log_blocks();
- for (int64_t block_id = 0; block_id < source_blocks; ++block_id) {
- uint8_t* block_bytes = source->blocks() + block_id * source_block_bytes;
+ uint32_t source_blocks = 1 << source->log_blocks();
+ for (uint32_t block_id = 0; block_id < source_blocks; ++block_id) {
+ const uint8_t* block_bytes = source->block_data(block_id, source_block_bytes);
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);

// For each non-empty source slot...
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
- constexpr int kSlotsPerBlock = 8;
- int num_full_slots =
- kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
+ int num_full_slots = SwissTable::kSlotsPerBlock -
+ static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
for (int local_slot_id = 0; local_slot_id < num_full_slots; ++local_slot_id) {
// Read group id and hash for this slot.
//
- uint64_t group_id =
- source->extract_group_id(block_bytes, local_slot_id, source_group_id_mask);
- int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
+ uint32_t group_id = SwissTable::extract_group_id(
+ block_bytes, local_slot_id, source_group_id_bits, source_group_id_mask);
+ uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
uint32_t hash = source->hashes()[global_slot_id];
// Insert partition id into the highest bits of hash, shifting the
// remaining hash bits right.
@@ -696,17 +697,18 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
}
}

- inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_id,
- uint32_t hash, int64_t max_block_id) {
+ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint32_t group_id,
+ uint32_t hash, uint32_t max_block_id) {
// Load the first block to visit for this hash
//
- int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks());
- int64_t block_id_mask = ((1LL << target->log_blocks()) - 1);
+ uint32_t block_id = SwissTable::block_id_from_hash(hash, target->log_blocks());
+ uint32_t block_id_mask = (1 << target->log_blocks()) - 1;
Member:

Would a signed int shift be UB if target->log_blocks() is 31? Or does that not happen anyway?

Contributor Author (@zanmato1984, Feb 17, 2025):

log_blocks() is guaranteed to be <= 29: the maximum number of rows in a Swiss table is 2^32 (we already have many guards on this), and each block contains 8 rows/slots. So the UB can't happen.
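
A minimal numeric check of the bound in this reply, as an illustration only; the 2^32-row cap itself is enforced elsewhere in the Swiss table code, and these constants are stand-ins rather than the real SwissTable members.

#include <cstdint>

constexpr uint64_t kMaxRows = uint64_t{1} << 32;  // hard cap on table rows
constexpr int kSlotsPerBlock = 8;                 // 8 rows/slots per block
constexpr int kMaxLogBlocks = 29;                 // log2(2^32 / 8)

static_assert(kMaxRows / kSlotsPerBlock == (uint64_t{1} << kMaxLogBlocks), "");

// With log_blocks <= 29, (1 << log_blocks) is at most 2^29, far below INT_MAX,
// so even the signed-int shift in `(1 << target->log_blocks()) - 1` cannot
// overflow; `1 << 31` on a 32-bit int would have been the UB case asked about.
constexpr uint32_t BlockIdMask(int log_blocks) {
  return (uint32_t{1} << log_blocks) - 1;
}

static_assert(BlockIdMask(29) == 0x1FFFFFFFu, "");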

int num_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks());
- int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t);
+ int num_block_bytes =
+ SwissTable::num_block_bytes_from_num_groupid_bits(num_group_id_bits);
ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0);
- uint8_t* block_bytes = target->blocks() + block_id * num_block_bytes;
+ const uint8_t* block_bytes = target->block_data(block_id, num_block_bytes);
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);

// Search for the first block with empty slots.
@@ -715,25 +717,23 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;
while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id) {
block_id = (block_id + 1) & block_id_mask;
- block_bytes = target->blocks() + block_id * num_block_bytes;
+ block_bytes = target->block_data(block_id, num_block_bytes);
block = *reinterpret_cast<const uint64_t*>(block_bytes);
}
if ((block & kHighBitOfEachByte) == 0) {
return false;
}
- constexpr int kSlotsPerBlock = 8;
- int local_slot_id =
- kSlotsPerBlock - static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
- int64_t global_slot_id = block_id * kSlotsPerBlock + local_slot_id;
- target->insert_into_empty_slot(static_cast<uint32_t>(global_slot_id), hash,
- static_cast<uint32_t>(group_id));
+ int local_slot_id = SwissTable::kSlotsPerBlock -
+ static_cast<int>(ARROW_POPCOUNT64(block & kHighBitOfEachByte));
+ uint32_t global_slot_id = SwissTable::global_slot_id(block_id, local_slot_id);
+ target->insert_into_empty_slot(global_slot_id, hash, group_id);
return true;
}

void SwissTableMerge::InsertNewGroups(SwissTable* target,
const std::vector<uint32_t>& group_ids,
const std::vector<uint32_t>& hashes) {
- int64_t num_blocks = 1LL << target->log_blocks();
+ uint32_t num_blocks = 1 << target->log_blocks();
for (size_t i = 0; i < group_ids.size(); ++i) {
std::ignore = InsertNewGroup(target, group_ids[i], hashes[i], num_blocks);
}
@@ -1191,7 +1191,7 @@ Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
// We want each partition to correspond to a range of block indices,
// so we also partition on the highest bits of the hash.
//
- return locals.batch_hashes[i] >> (31 - log_num_prtns_) >> 1;
+ return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_);
},
[&locals](int64_t i, int pos) {
locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);
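
The last hunk above switches the partition id computation to a single shift by SwissTable::bits_hash_ - log_num_prtns_. Below is a small sketch of the hash-to-partition-to-block relationship it relies on, assuming bits_hash_ == 32 and that block_id_from_hash takes the top log_blocks bits of the hash; both are inferences from this diff, not quotes of the real SwissTable definitions.

#include <cstdint>

constexpr int kBitsHash = 32;

// Top log_num_prtns bits of the hash (the new single-shift form);
// valid for 1 <= log_num_prtns <= kBitsHash - 1.
constexpr uint32_t PartitionIdFromHash(uint32_t hash, int log_num_prtns) {
  return hash >> (kBitsHash - log_num_prtns);
}

// Top log_blocks bits of the hash, mirroring what block_id_from_hash is
// assumed to compute.
constexpr uint32_t BlockIdFromHash(uint32_t hash, int log_blocks) {
  return hash >> (kBitsHash - log_blocks);
}

// Because log_num_prtns <= log_blocks, the block id of any hash in partition p
// starts with the bits of p, so partition p owns the contiguous block range
// [p << (log_blocks - log_num_prtns), ((p + 1) << (log_blocks - log_num_prtns)) - 1],
// which is the same upper bound MergePartition computes as target_max_block_id.
static_assert(BlockIdFromHash(0xDEADBEEFu, 10) >> (10 - 3) ==
                  PartitionIdFromHash(0xDEADBEEFu, 3),
              "");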
4 changes: 2 additions & 2 deletions cpp/src/arrow/acero/swiss_join_internal.h
@@ -380,8 +380,8 @@ class SwissTableMerge {
// Max block id value greater or equal to the number of blocks guarantees that
// the search will not be stopped.
//
- static inline bool InsertNewGroup(SwissTable* target, uint64_t group_id, uint32_t hash,
- int64_t max_block_id);
+ static inline bool InsertNewGroup(SwissTable* target, uint32_t group_id, uint32_t hash,
+ uint32_t max_block_id);
};

struct SwissTableWithKeys {
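
The max_block_id convention documented in the header comment above (a bound greater than or equal to the number of blocks guarantees the search is never stopped) can be illustrated with a toy version of the probing loop from swiss_join.cc. This mirrors only the cutoff logic, not the real SwissTable block layout, and all names here are illustrative.

#include <cstdint>
#include <vector>

// Returns the id of the first non-full block at or after start_block, or -1 if
// the probe reaches the max_block_id cutoff while blocks are still full.
int64_t ProbeForFreeBlock(const std::vector<bool>& block_full,
                          uint32_t start_block, uint32_t max_block_id) {
  uint32_t num_blocks = static_cast<uint32_t>(block_full.size());
  uint32_t block_id_mask = num_blocks - 1;  // assumes num_blocks is a power of two
  uint32_t block_id = start_block;
  // Mirrors `while ((block & kHighBitOfEachByte) == 0 && block_id < max_block_id)`:
  // keep probing while the current block is full and the cutoff is not reached.
  // If max_block_id >= num_blocks, the cutoff can never trigger (block_id wraps
  // and stays below num_blocks), so the probe only ends at a free block; the
  // caller must then guarantee one exists, as the real table does.
  while (block_full[block_id] && block_id < max_block_id) {
    block_id = (block_id + 1) & block_id_mask;
  }
  return block_full[block_id] ? -1 : static_cast<int64_t>(block_id);
}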