
GH-45506: [C++][Acero] More overflow-safe Swiss table #45515

Open · zanmato1984 wants to merge 6 commits into main from more-overflow-safe-swiss-table
Conversation

@zanmato1984 (Contributor) commented Feb 12, 2025:

Rationale for this change

See #45506.

What changes are included in this PR?

  1. Abstract the current overflow-prone block data access into functions that do proper type promotion to avoid overflow (see the sketch after this list). Also remove the old block base address accessor.
  2. Unify the data types used for the various concepts as they naturally are (i.e., w/o explicit promotion): uint32_t for block_id, int for num_xxx_bits/bytes, uint32_t for group_id, int for local_slot_id, and uint32_t for global_slot_id.
  3. Abstract several constants and utility functions for readability and maintainability.
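For illustration, the overflow-safe access pattern in item 1 amounts to something like the following sketch (names and signature assumed; the PR's actual block_data accessor may differ):

```cpp
#include <cstdint>

// Promote the 32-bit block id to 64 bits *before* multiplying, so the byte
// offset into the blocks buffer cannot wrap around for large tables.
inline uint8_t* block_data(uint8_t* blocks, uint32_t block_id,
                           int num_block_bytes) {
  return blocks + static_cast<uint64_t>(block_id) *
                      static_cast<uint64_t>(num_block_bytes);
}
```

The point is that a bare uint32_t * int multiply is done in 32 bits and would silently wrap once block_id * num_block_bytes exceeds 2^32.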

Are these changes tested?

Existing tests should suffice.

It is really hard (gosh I did try) to create a concrete test case that fails w/o this change and passes w/ this change.

Are there any user-facing changes?

None.

@zanmato1984 (Contributor, Author):

Most of this change is cleanup and refinement. @pitrou, mind taking a look? Thanks.

@github-actions bot added the "awaiting review" label (Feb 12, 2025)
@@ -643,37 +643,36 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
//
int source_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
@zanmato1984 (Contributor, Author):

After cleaning up these unnecessary 64-bit values in this file, we can further clean up some temp states, as mentioned in #45336 (comment).
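Concretely (a hypothetical narrowed form, not code from this PR), the 64-bit mask above could then become:

```cpp
// Hypothetical 32-bit variant of the mask line above once group ids are
// uniformly 32-bit; the uint64_t mask and the (64 - bits) shift trick become
// unneeded. 1ULL keeps the shift well-defined when the bit count is 32.
uint32_t source_group_id_mask =
    static_cast<uint32_t>((1ULL << source_group_id_bits) - 1);
```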

@github-actions bot added the "awaiting committer review" label and removed the "awaiting review" label (Feb 12, 2025)
@zanmato1984 force-pushed the more-overflow-safe-swiss-table branch from af1c470 to 1dd5c08 (Feb 12, 2025 12:53)
@@ -81,18 +81,29 @@ class ARROW_EXPORT SwissTable {

void num_inserted(uint32_t i) { num_inserted_ = i; }

uint8_t* blocks() const { return blocks_->mutable_data(); }
@zanmato1984 (Contributor, Author):

This is the source of all evil. Let's get rid of it!

Comment on lines 119 to 122
uint32_t group_id = *reinterpret_cast<const uint32_t*>(
block_data(block_id, num_block_bytes) + local_slots[id] * num_groupid_bytes +
bytes_status_in_block_);
group_id &= group_id_mask;
@pitrou (Member) commented Feb 12, 2025:

So we always issue a 32-bit load but then we optionally mask if the actual group id width is smaller? Don't we risk reading past block_data bounds here?

(also, should we use an unaligned load? see the SafeLoad and SafeLoadAs utility functions)

@zanmato1984 (Contributor, Author):

> So we always issue a 32-bit load but then we optionally mask if the actual group id width is smaller? Don't we risk reading past block_data bounds here?
There will always be padding_ (64) extra bytes at the buffer end.
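A minimal sketch of why the over-read is safe, assuming the allocation always keeps padding_ (64) spare bytes past the logical end of the blocks buffer (names assumed, not the PR's code):

```cpp
#include <cstdint>

// The 4-byte load may extend up to 3 bytes past the last group id, but the
// trailing padding guarantees it stays inside the allocation; the mask then
// discards whatever padding bytes were read.
inline uint32_t LoadMaskedGroupId(const uint8_t* group_id_ptr,
                                  uint32_t group_id_mask) {
  uint32_t word = *reinterpret_cast<const uint32_t*>(group_id_ptr);
  return word & group_id_mask;
}
```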

> (also, should we use an unaligned load? see the SafeLoad and SafeLoadAs utility functions)

It seems so indeed, though I didn't change how the original code does it.

I'll update later.

@zanmato1984 (Contributor, Author):

> > (also, should we use an unaligned load? see the SafeLoad and SafeLoadAs utility functions)
>
> It seems so indeed, though I didn't change how the original code does it. I'll update later.

OK, turns out I was wrong: the original code uses an aligned read and my change made it unaligned. I'll need to update it with more care. Thank you for pointing this out.

@zanmato1984 (Contributor, Author):

Changed back to the original aligned read with minor refinement.

Besides, I've also done some more cleanup while fixing the alignment issue. See my latest commit. Thanks.
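For context, an "aligned 32-bit read" of a narrower group id typically means loading the 4-byte-aligned word containing it, then shifting and masking. A sketch under assumed little-endian layout and a 4-byte-aligned group-id array (not necessarily the commit's exact code):

```cpp
#include <cstddef>
#include <cstdint>

// Load the aligned 32-bit word containing the group id, shift its bytes into
// place, and mask. Assumes little-endian byte order and that group_ids is at
// least 4-byte aligned (it follows the 8 status bytes of an aligned block).
inline uint32_t AlignedLoadGroupId(const uint8_t* group_ids, int local_slot,
                                   int num_groupid_bytes, uint32_t mask) {
  size_t byte_offset = static_cast<size_t>(local_slot) * num_groupid_bytes;
  const uint32_t* word_ptr = reinterpret_cast<const uint32_t*>(
      group_ids + (byte_offset & ~size_t{3}));
  int shift = static_cast<int>(byte_offset & size_t{3}) * 8;
  return (*word_ptr >> shift) & mask;
}
```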

return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1);
}

static constexpr int bytes_status_in_block_ = 8;
@pitrou (Member):
Usually, compile-time constants should follow the naming convention kBytesStatusInBlock. Perhaps we can have a renaming pass in this PR or another one?

@zanmato1984 (Contributor, Author):

Yes, it is supposed to be. However, I was following the naming convention of several existing compile-time constants in this class. I would like to change them all in another PR to keep this one solely focused on overflow prevention.

Comment on lines +387 to +389
uint32_t mask = num_groupid_bytes == 1 ? 0xFF
: num_groupid_bytes == 2 ? 0xFFFF
: 0xFFFFFFFF;
@pitrou (Member):
Is there a reason for expanding the possible values instead of simply using the usual bitshift formula?

@zanmato1984 (Contributor, Author):
Not particularly. This is just moving the original code.
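For reference, the bitshift formula the reviewer alludes to would be equivalent for num_groupid_bytes in {1, 2, 4} (a sketch, not code from the PR):

```cpp
#include <cstdint>

// Equivalent to the expanded ternary above; 1ULL avoids undefined behavior
// when num_groupid_bytes == 4 (which would otherwise be a 32-bit shift by 32).
inline uint32_t group_id_mask_from_bytes(int num_groupid_bytes) {
  return static_cast<uint32_t>((1ULL << (num_groupid_bytes * 8)) - 1);
}
```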

__m128i group_id_lo = _mm256_i64gather_epi32(elements, pos_lo, 1);
__m128i group_id_hi = _mm256_i64gather_epi32(elements, pos_hi, 1);
local_slot_lo =
_mm256_mul_epu32(local_slot_lo, _mm256_set1_epi32(num_groupid_bytes));
@pitrou (Member):
This is using a 32-bit multiply even though local_slot_lo is supposed to be a vector of 64-bit ints? This might be correct because most bytes are zero, but I would at least expect an explanatory comment :)

@pitrou (Member):
Ah, my bad, _mm256_mul_epu32 is actually a 64-bit multiply.
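For readers following along: _mm256_i64gather_epi32 gathers 32-bit elements at 64-bit offsets, and _mm256_mul_epu32 multiplies the unsigned low 32 bits of each 64-bit lane into a full 64-bit product. A scalar sketch of one _mm256_mul_epu32 lane:

```cpp
#include <cstdint>

// Per-lane scalar equivalent of _mm256_mul_epu32: the high halves of the
// inputs are ignored, and the full 64-bit product is kept, so zero-extended
// 32-bit values multiply without overflow.
inline uint64_t mul_epu32_lane(uint64_t a, uint64_t b) {
  return static_cast<uint64_t>(static_cast<uint32_t>(a)) *
         static_cast<uint64_t>(static_cast<uint32_t>(b));
}
```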

@zanmato1984 force-pushed the more-overflow-safe-swiss-table branch from 1f9dde4 to aeb8b9d (Feb 13, 2025 14:40)
uint64_t group_id_mask) const;
static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot,
int num_group_id_bits) {
// Extract group id using aligned 32-bit read.
@pitrou (Member):
Why not use SafeLoad as in other places already?
(also, since this is non-trivial, factoring out the loading of a group id could go into a dedicated inline function)

@zanmato1984 (Contributor, Author):

The original code uses an aligned read plus masking, so I'm following it (possibly for performance's sake, I guess?).

If SafeLoad is preferred (i.e., it doesn't hurt performance), then yes, it is possible to factor this piece of code out.

@zanmato1984 (Contributor, Author):

For the record, there are three places doing group id extraction:

  1. Here, extracting a single group id, publicly used by the swiss join: currently using an aligned read plus masking;
  2. extract_group_ids, extracting a vector of group ids, used internally: using an aligned read without masking (the number of bits is fixed as a template parameter);
  3. grow_double, extracting a single group id inside a big loop, inlined: using an unaligned read plus masking.

I think we should at least keep 2) as is because it makes perfect sense. 1) and 3) can be unified, either aligned or unaligned (a sketch of one option follows below).

What do you think?
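One possible unified form for 1) and 3), sketched under the assumption that Arrow's SafeLoadAs unaligned-load helper is acceptable performance-wise (kBytesStatusInBlock is a hypothetical name for the 8 status bytes; this is not the PR's final code):

```cpp
#include <cstdint>

#include "arrow/util/ubsan.h"  // arrow::util::SafeLoadAs

constexpr int kBytesStatusInBlock = 8;  // status bytes preceding the group ids

// Hypothetical unified extraction: unaligned 32-bit load plus masking. The
// possible over-read past the last group id relies on the buffer's trailing
// padding, per the earlier discussion.
inline uint32_t ExtractGroupId(const uint8_t* block_ptr, int local_slot,
                               int num_group_id_bits) {
  const uint8_t* group_id_ptr =
      block_ptr + kBytesStatusInBlock + local_slot * (num_group_id_bits / 8);
  uint32_t mask = static_cast<uint32_t>((1ULL << num_group_id_bits) - 1);
  return arrow::util::SafeLoadAs<uint32_t>(group_id_ptr) & mask;
}
```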

@pitrou (Member):

> I think we should at least keep 2) as is because it makes perfect sense. 1) and 3) can be unified, either aligned or unaligned.

Agreed. Feel free to choose whatever approach you prefer!

@zanmato1984 (Contributor, Author):
Addressed. Thank you for the suggestion.

This PR is good for review again.
