Skip to content

Commit

Permalink
Enable parallel dataset writing + various nits
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Feb 12, 2025
1 parent b24b8c2 commit b3a7585
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 70 deletions.
128 changes: 74 additions & 54 deletions cpp/src/arrow/dataset/file_parquet_encryption_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "gtest/gtest.h"

#include "arrow/array.h"
#include "arrow/compute/api_vector.h"
#include "arrow/dataset/dataset.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_parquet.h"
Expand All @@ -29,6 +30,7 @@
#include "arrow/io/api.h"
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/testing/future_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/type.h"
Expand Down Expand Up @@ -58,9 +60,26 @@ struct EncryptionTestParam {
bool concurrently;
};

// Builds a human-readable gtest parameter-name suffix for an
// EncryptionTestParam, e.g. "UniformEncryptionThreaded" or "ColumnKeysSerial".
std::string PrintParam(const testing::TestParamInfo<EncryptionTestParam>& info) {
  const auto& param = info.param;
  std::string name = param.uniform_encryption ? "UniformEncryption" : "ColumnKeys";
  name += param.concurrently ? "Threaded" : "Serial";
  return name;
}

const auto kAllParamValues =
::testing::Values(EncryptionTestParam{false, false}, EncryptionTestParam{true, false},
EncryptionTestParam{false, true}, EncryptionTestParam{true, true});

// Base class to test writing and reading encrypted dataset.
class DatasetEncryptionTestBase : public testing::TestWithParam<EncryptionTestParam> {
public:
#ifdef ARROW_VALGRIND
static constexpr int kConcurrentIterations = 4;
#else
static constexpr int kConcurrentIterations = 20;
#endif

// This function creates a mock file system using the current time point, creates a
// directory with the given base directory path, and writes a dataset to it using
// provided Parquet file write options. The function also checks if the written files
Expand All @@ -80,6 +99,8 @@ class DatasetEncryptionTestBase : public testing::TestWithParam<EncryptionTestPa

// Init dataset and partitioning.
ASSERT_NO_FATAL_FAILURE(PrepareTableAndPartitioning());
ASSERT_OK_AND_ASSIGN(expected_table_, table_->CombineChunks());
ASSERT_OK_AND_ASSIGN(expected_table_, SortTable(expected_table_));

// Prepare encryption properties.
std::unordered_map<std::string, std::string> key_map;
Expand Down Expand Up @@ -119,31 +140,29 @@ class DatasetEncryptionTestBase : public testing::TestWithParam<EncryptionTestPa
EXPECT_OK_AND_ASSIGN(auto scanner_builder, dataset->NewScan());
// ideally, we would have UseThreads(concurrently) here, but that is not working
// unless GH-26818 (https://github.com/apache/arrow/issues/26818) is fixed
ARROW_EXPECT_OK(scanner_builder->UseThreads(false));
ARROW_EXPECT_OK(scanner_builder->UseThreads(GetParam().concurrently));
EXPECT_OK_AND_ASSIGN(auto scanner, scanner_builder->Finish());

if (GetParam().concurrently) {
// have a notable number of threads to exhibit multi-threading issues
// Have a notable number of threads to exhibit multi-threading issues
ASSERT_OK_AND_ASSIGN(auto pool, arrow::internal::ThreadPool::Make(16));
std::vector<Future<>> threads;
std::vector<Future<>> futures;

// write dataset above multiple times concurrently to see that is thread-safe.
for (size_t i = 1; i <= 100; ++i) {
// Write dataset above multiple times concurrently to see that it is thread-safe.
for (int i = 1; i <= kConcurrentIterations; ++i) {
FileSystemDatasetWriteOptions write_options;
write_options.file_write_options = parquet_file_write_options;
write_options.filesystem = file_system_;
write_options.base_dir = "thread-" + std::to_string(i);
write_options.partitioning = partitioning_;
write_options.basename_template = "part{i}.parquet";
threads.push_back(
futures.push_back(
DeferNotOk(pool->Submit(FileSystemDataset::Write, write_options, scanner)));
}
pool->WaitForIdle();

// assert all jobs succeeded
for (auto& thread : threads) {
thread.Wait();
ASSERT_TRUE(thread.state() == FutureState::SUCCESS);
// Assert all jobs succeeded
for (auto& future : futures) {
ASSERT_FINISHES_OK(future);
}
} else {
FileSystemDatasetWriteOptions write_options;
Expand All @@ -158,7 +177,7 @@ class DatasetEncryptionTestBase : public testing::TestWithParam<EncryptionTestPa

virtual void PrepareTableAndPartitioning() = 0;

Result<std::shared_ptr<Dataset>> CreateDataset(
Result<std::shared_ptr<Dataset>> OpenDataset(
std::string_view base_dir, const std::shared_ptr<ParquetFileFormat>& file_format) {
// Get FileInfo objects for all files under the base directory
fs::FileSelector selector;
Expand Down Expand Up @@ -193,66 +212,75 @@ class DatasetEncryptionTestBase : public testing::TestWithParam<EncryptionTestPa
auto file_format = std::make_shared<ParquetFileFormat>();
file_format->default_fragment_scan_options = std::move(parquet_scan_options);

ASSERT_OK_AND_ASSIGN(auto expected_table, table_->CombineChunks());

if (GetParam().concurrently) {
// Create the dataset
ASSERT_OK_AND_ASSIGN(auto dataset, CreateDataset("thread-1", file_format));
ASSERT_OK_AND_ASSIGN(auto dataset, OpenDataset("thread-1", file_format));

// have a notable number of threads to exhibit multi-threading issues
// Have a notable number of threads to exhibit multi-threading issues
ASSERT_OK_AND_ASSIGN(auto pool, arrow::internal::ThreadPool::Make(16));
std::vector<Future<std::shared_ptr<Table>>> threads;
std::vector<Future<std::shared_ptr<Table>>> futures;

// Read dataset above multiple times concurrently to see that it is thread-safe.
for (size_t i = 0; i < 100; ++i) {
threads.push_back(
DeferNotOk(pool->Submit(DatasetEncryptionTestBase::read, dataset)));
for (int i = 0; i < kConcurrentIterations; ++i) {
futures.push_back(DeferNotOk(pool->Submit(ReadDataset, dataset)));
}
pool->WaitForIdle();

// assert correctness of jobs
for (auto& thread : threads) {
ASSERT_OK_AND_ASSIGN(auto read_table, thread.result());
AssertTablesEqual(*read_table, *expected_table);
// Assert correctness of jobs
for (auto& future : futures) {
ASSERT_OK_AND_ASSIGN(auto read_table, future.result());
CheckDatasetResults(read_table);
}

// finally check datasets written by all other threads are as expected
for (size_t i = 2; i <= 100; ++i) {
// Finally check datasets written by all other threads are as expected
for (int i = 2; i <= kConcurrentIterations; ++i) {
ASSERT_OK_AND_ASSIGN(dataset,
CreateDataset("thread-" + std::to_string(i), file_format));
ASSERT_OK_AND_ASSIGN(auto read_table, DatasetEncryptionTestBase::read(dataset));
AssertTablesEqual(*read_table, *expected_table);
OpenDataset("thread-" + std::to_string(i), file_format));
ASSERT_OK_AND_ASSIGN(auto read_table, ReadDataset(dataset));
CheckDatasetResults(read_table);
}
} else {
// Create the dataset
ASSERT_OK_AND_ASSIGN(auto dataset, CreateDataset(kBaseDir, file_format));
ASSERT_OK_AND_ASSIGN(auto dataset, OpenDataset(kBaseDir, file_format));

// Reuse the dataset above to scan it twice to make sure decryption works correctly.
for (size_t i = 0; i < 2; ++i) {
ASSERT_OK_AND_ASSIGN(auto read_table, read(dataset));
AssertTablesEqual(*read_table, *expected_table);
for (int i = 0; i < 2; ++i) {
ASSERT_OK_AND_ASSIGN(auto read_table, ReadDataset(dataset));
CheckDatasetResults(read_table);
}
}
}

static Result<std::shared_ptr<Table>> read(const std::shared_ptr<Dataset>& dataset) {
static Result<std::shared_ptr<Table>> ReadDataset(
const std::shared_ptr<Dataset>& dataset) {
// Read dataset into table
ARROW_ASSIGN_OR_RAISE(auto scanner_builder, dataset->NewScan());
ARROW_ASSIGN_OR_RAISE(auto scanner, scanner_builder->Finish());
ARROW_EXPECT_OK(scanner_builder->UseThreads(GetParam().concurrently));
ARROW_ASSIGN_OR_RAISE(auto read_table, scanner->ToTable());
return scanner->ToTable();
}

void CheckDatasetResults(const std::shared_ptr<Table>& table) {
ASSERT_OK(table->ValidateFull());
// Make results comparable despite ordering and chunking differences
ASSERT_OK_AND_ASSIGN(auto combined_table, table->CombineChunks());
ASSERT_OK_AND_ASSIGN(auto sorted_table, SortTable(combined_table));
AssertTablesEqual(*sorted_table, *expected_table_);
}

// Verify the data was read correctly
ARROW_ASSIGN_OR_RAISE(auto combined_table, read_table->CombineChunks());
// Validate the table
RETURN_NOT_OK(combined_table->ValidateFull());
return combined_table;
// Returns a copy of `table` sorted by column "a", making dataset read
// results (which may arrive in any order) comparable deterministically.
// Relies on column "a" having statistically unique values.
Result<std::shared_ptr<Table>> SortTable(const std::shared_ptr<Table>& table) {
  compute::SortOptions sort_options({compute::SortKey("a")});
  ARROW_ASSIGN_OR_RAISE(auto sort_indices, compute::SortIndices(table, sort_options));
  ARROW_ASSIGN_OR_RAISE(auto sorted_datum, compute::Take(table, sort_indices));
  EXPECT_EQ(sorted_datum.kind(), Datum::TABLE);
  return sorted_datum.table();
}

protected:
std::string base_dir_ = GetParam().concurrently ? "thread-1" : std::string(kBaseDir);
std::shared_ptr<fs::FileSystem> file_system_;
std::shared_ptr<Table> table_;
std::shared_ptr<Table> table_, expected_table_;
std::shared_ptr<Partitioning> partitioning_;
std::shared_ptr<parquet::encryption::CryptoFactory> crypto_factory_;
std::shared_ptr<parquet::encryption::KmsConnectionConfig> kms_connection_config_;
Expand Down Expand Up @@ -325,12 +353,8 @@ TEST_P(DatasetEncryptionTest, ReadSingleFile) {
ASSERT_EQ(checked_pointer_cast<Int64Array>(table->column(2)->chunk(0))->GetView(0), 1);
}

INSTANTIATE_TEST_SUITE_P(DatasetEncryptionTest, DatasetEncryptionTest,
::testing::Values(EncryptionTestParam{false, false},
EncryptionTestParam{true, false}));
INSTANTIATE_TEST_SUITE_P(DatasetEncryptionTestThreaded, DatasetEncryptionTest,
::testing::Values(EncryptionTestParam{false, true},
EncryptionTestParam{true, true}));
INSTANTIATE_TEST_SUITE_P(DatasetEncryptionTest, DatasetEncryptionTest, kAllParamValues,
PrintParam);

// GH-39444: This test covers the case where parquet dataset scanner crashes when
// processing encrypted datasets over 2^15 rows in multi-threaded mode.
Expand All @@ -341,7 +365,7 @@ class LargeRowCountEncryptionTest : public DatasetEncryptionTestBase {
// Specifically chosen to be greater than batch size for triggering prefetch.
constexpr int kRowCount = 32769;
// Number of batches
constexpr int kBatchCount = 10;
constexpr int kBatchCount = 5;

// Create multiple random floating-point arrays with large number of rows.
arrow::random::RandomArrayGenerator rand_gen(0);
Expand All @@ -364,11 +388,7 @@ TEST_P(LargeRowCountEncryptionTest, ReadEncryptLargeRowCount) {
}

INSTANTIATE_TEST_SUITE_P(LargeRowCountEncryptionTest, LargeRowCountEncryptionTest,
::testing::Values(EncryptionTestParam{false, false},
EncryptionTestParam{true, false}));
INSTANTIATE_TEST_SUITE_P(LargeRowCountEncryptionTestThreaded, LargeRowCountEncryptionTest,
::testing::Values(EncryptionTestParam{false, true},
EncryptionTestParam{true, true}));
kAllParamValues, PrintParam);

} // namespace dataset
} // namespace arrow
4 changes: 1 addition & 3 deletions cpp/src/parquet/column_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,7 @@ class SerializedPageReader : public PageReader {

// The CryptoContext used by this PageReader.
CryptoContext crypto_ctx_;
// This PageReader has its own copy of crypto_ctx_->meta_decryptor and
// crypto_ctx_->data_decryptor in order to be thread-safe. Do not mutate (update) the
// instances of crypto_ctx_.
// This PageReader has its own Decryptor instances in order to be thread-safe.
std::shared_ptr<Decryptor> meta_decryptor_;
std::shared_ptr<Decryptor> data_decryptor_;

Expand Down
30 changes: 17 additions & 13 deletions cpp/src/parquet/encryption/encryption_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ class AesCryptoContext {
using CipherContext = std::unique_ptr<EVP_CIPHER_CTX, decltype(&DeleteCipherContext)>;

static CipherContext NewCipherContext() {
return CipherContext(EVP_CIPHER_CTX_new(), DeleteCipherContext);
auto ctx = CipherContext(EVP_CIPHER_CTX_new(), DeleteCipherContext);
if (!ctx) {
throw ParquetException("Couldn't init cipher context");
}
return ctx;
}

int32_t aes_mode_;
Expand Down Expand Up @@ -124,7 +128,7 @@ class AesEncryptor::AesEncryptorImpl : public AesCryptoContext {
}

private:
[[nodiscard]] CipherContext NewCipherContext() const;
[[nodiscard]] CipherContext MakeCipherContext() const;

int32_t GcmEncrypt(span<const uint8_t> plaintext, span<const uint8_t> key,
span<const uint8_t> nonce, span<const uint8_t> aad,
Expand All @@ -139,9 +143,9 @@ AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id,
bool write_length)
: AesCryptoContext(alg_id, key_len, metadata, write_length) {}

AesCryptoContext::CipherContext AesEncryptor::AesEncryptorImpl::NewCipherContext() const {
auto ctx = AesCryptoContext::NewCipherContext();
if (!ctx) throw ParquetException("Couldn't init cipher context");
AesCryptoContext::CipherContext AesEncryptor::AesEncryptorImpl::MakeCipherContext()
const {
auto ctx = NewCipherContext();
if (kGcmMode == aes_mode_) {
// Init AES-GCM with specified key length
if (16 == key_length_) {
Expand Down Expand Up @@ -232,7 +236,7 @@ int32_t AesEncryptor::AesEncryptorImpl::GcmEncrypt(span<const uint8_t> plaintext
throw ParquetException(ss.str());
}

auto ctx = NewCipherContext();
auto ctx = MakeCipherContext();

// Setting key and IV (nonce)
if (1 != EVP_EncryptInit_ex(ctx.get(), nullptr, nullptr, key.data(), nonce.data())) {
Expand Down Expand Up @@ -316,7 +320,7 @@ int32_t AesEncryptor::AesEncryptorImpl::CtrEncrypt(span<const uint8_t> plaintext
std::copy(nonce.begin(), nonce.begin() + kNonceLength, iv.begin());
iv[kCtrIvLength - 1] = 1;

auto ctx = NewCipherContext();
auto ctx = MakeCipherContext();

// Setting key and IV
if (1 != EVP_EncryptInit_ex(ctx.get(), nullptr, nullptr, key.data(), iv.data())) {
Expand Down Expand Up @@ -420,7 +424,7 @@ class AesDecryptor::AesDecryptorImpl : AesCryptoContext {
}

private:
[[nodiscard]] CipherContext NewCipherContext() const;
[[nodiscard]] CipherContext MakeCipherContext() const;

/// Get the actual ciphertext length, inclusive of the length buffer length,
/// and validate that the provided buffer size is large enough.
Expand All @@ -445,9 +449,9 @@ AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id,
bool contains_length)
: AesCryptoContext(alg_id, key_len, metadata, contains_length) {}

AesCryptoContext::CipherContext AesDecryptor::AesDecryptorImpl::NewCipherContext() const {
auto ctx = AesCryptoContext::NewCipherContext();
if (!ctx) throw ParquetException("Couldn't init cipher context");
AesCryptoContext::CipherContext AesDecryptor::AesDecryptorImpl::MakeCipherContext()
const {
auto ctx = NewCipherContext();
if (kGcmMode == aes_mode_) {
// Init AES-GCM with specified key length
if (16 == key_length_) {
Expand Down Expand Up @@ -589,7 +593,7 @@ int32_t AesDecryptor::AesDecryptorImpl::GcmDecrypt(span<const uint8_t> ciphertex
std::copy(ciphertext.begin() + ciphertext_len - kGcmTagLength,
ciphertext.begin() + ciphertext_len, tag.begin());

auto ctx = NewCipherContext();
auto ctx = MakeCipherContext();

// Setting key and IV
if (1 != EVP_DecryptInit_ex(ctx.get(), nullptr, nullptr, key.data(), nonce.data())) {
Expand Down Expand Up @@ -665,7 +669,7 @@ int32_t AesDecryptor::AesDecryptorImpl::CtrDecrypt(span<const uint8_t> ciphertex
// is set to 1.
iv[kCtrIvLength - 1] = 1;

auto ctx = NewCipherContext();
auto ctx = MakeCipherContext();

// Setting key and IV
if (1 != EVP_DecryptInit_ex(ctx.get(), nullptr, nullptr, key.data(), iv.data())) {
Expand Down

0 comments on commit b3a7585

Please sign in to comment.