diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index cf690000507f3..d70ddec0fb526 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -851,6 +851,7 @@ NodeConfig::NodeConfig() { NONE_PROP(kNodeIp), NONE_PROP(kNodeInternalAddress), NONE_PROP(kNodeLocation), + NONE_PROP(kNodePrometheusExecutorThreads), }; } @@ -863,6 +864,16 @@ std::string NodeConfig::nodeEnvironment() const { return requiredProperty(kNodeEnvironment); } +int NodeConfig::prometheusExecutorThreads() const { + static constexpr int + kNodePrometheusExecutorThreadsDefault = 2; + auto resultOpt = optionalProperty(kNodePrometheusExecutorThreads); + if (resultOpt.hasValue()) { + return resultOpt.value(); + } + return kNodePrometheusExecutorThreadsDefault; +} + std::string NodeConfig::nodeId() const { auto resultOpt = optionalProperty(kNodeId); if (resultOpt.hasValue()) { diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index df7b896169da4..4ed8b81c81c45 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -978,6 +978,7 @@ class NodeConfig : public ConfigBase { static constexpr std::string_view kNodeInternalAddress{ "node.internal-address"}; static constexpr std::string_view kNodeLocation{"node.location"}; + static constexpr std::string_view kNodePrometheusExecutorThreads{"node.prometheus.executor-threads"}; NodeConfig(); @@ -987,6 +988,8 @@ class NodeConfig : public ConfigBase { std::string nodeEnvironment() const; + int prometheusExecutorThreads() const; + std::string nodeId() const; std::string nodeInternalAddress( diff --git a/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp b/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp index c119480ae4654..2193d9856baf1 100644 --- a/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp +++ b/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp @@ -212,6 +212,11 @@ TEST_F(ConfigTest, optionalNodeConfigs) { init(config, {{std::string(NodeConfig::kNodeIp), "127.0.0.1"}}); ASSERT_EQ( config.nodeInternalAddress([]() { return "0.0.0.0"; }), "127.0.0.1"); + + // make sure "node.kNodePrometheusExecutorThreads" works too + init(config, {{std::string(NodeConfig::kNodePrometheusExecutorThreads), "4"}}); + ASSERT_EQ( + config.prometheusExecutorThreads(), 4); } TEST_F(ConfigTest, optionalSystemConfigsWithDefault) { diff --git a/presto-native-execution/presto_cpp/main/runtime-metrics/PrometheusStatsReporter.cpp b/presto-native-execution/presto_cpp/main/runtime-metrics/PrometheusStatsReporter.cpp index 6ed3ea7c39594..225bc99975ac6 100644 --- a/presto-native-execution/presto_cpp/main/runtime-metrics/PrometheusStatsReporter.cpp +++ b/presto-native-execution/presto_cpp/main/runtime-metrics/PrometheusStatsReporter.cpp @@ -47,8 +47,9 @@ struct PrometheusStatsReporter::PrometheusImpl { }; PrometheusStatsReporter::PrometheusStatsReporter( - const std::map& labels) { - impl_ = std::make_shared(labels); + const std::map& labels, int numThreads) + : executor_(std::make_shared(numThreads)), + impl_(std::make_shared(labels)) { } void PrometheusStatsReporter::registerMetricExportType( @@ -161,36 +162,38 @@ void PrometheusStatsReporter::addMetricValue( void PrometheusStatsReporter::addMetricValue(const char* key, size_t value) const { - auto metricIterator = registeredMetricsMap_.find(key); - if (metricIterator == registeredMetricsMap_.end()) { - VLOG(1) << "addMetricValue called for unregistered metric " << key; - return; - } - auto statsInfo = metricIterator->second; - switch (statsInfo.statType) { - case velox::StatType::COUNT: { - auto* counter = - reinterpret_cast<::prometheus::Counter*>(statsInfo.metricPtr); - counter->Increment(static_cast(value)); - break; - } - case velox::StatType::SUM: { - auto* gauge = reinterpret_cast<::prometheus::Gauge*>(statsInfo.metricPtr); - gauge->Increment(static_cast(value)); - break; - } - case velox::StatType::AVG: - case velox::StatType::RATE: { - // Overrides the existing state. - auto* gauge = reinterpret_cast<::prometheus::Gauge*>(statsInfo.metricPtr); - gauge->Set(static_cast(value)); - break; + executor_->add([this, key, value]() { + auto metricIterator = registeredMetricsMap_.find(key); + if (metricIterator == registeredMetricsMap_.end()) { + VLOG(1) << "addMetricValue called for unregistered metric " << key; + return; } - default: - VELOX_UNSUPPORTED( - "Unsupported metric type {}", - velox::statTypeString(statsInfo.statType)); - }; + auto statsInfo = metricIterator->second; + switch (statsInfo.statType) { + case velox::StatType::COUNT: { + auto* counter = + reinterpret_cast<::prometheus::Counter*>(statsInfo.metricPtr); + counter->Increment(static_cast(value)); + break; + } + case velox::StatType::SUM: { + auto* gauge = reinterpret_cast<::prometheus::Gauge*>(statsInfo.metricPtr); + gauge->Increment(static_cast(value)); + break; + } + case velox::StatType::AVG: + case velox::StatType::RATE: { + // Overrides the existing state. + auto* gauge = reinterpret_cast<::prometheus::Gauge*>(statsInfo.metricPtr); + gauge->Set(static_cast(value)); + break; + } + default: + VELOX_UNSUPPORTED( + "Unsupported metric type {}", + velox::statTypeString(statsInfo.statType)); + }; + }); } void PrometheusStatsReporter::addMetricValue( @@ -208,22 +211,24 @@ void PrometheusStatsReporter::addHistogramMetricValue( void PrometheusStatsReporter::addHistogramMetricValue( const char* key, size_t value) const { - auto metricIterator = registeredMetricsMap_.find(key); - if (metricIterator == registeredMetricsMap_.end()) { - VLOG(1) << "addMetricValue for unregistered metric " << key; - return; - } - auto histogram = reinterpret_cast<::prometheus::Histogram*>( - metricIterator->second.metricPtr); - histogram->Observe(value); - - std::string summaryKey = std::string(key).append(kSummarySuffix); - metricIterator = registeredMetricsMap_.find(summaryKey); - if (metricIterator != registeredMetricsMap_.end()) { - auto summary = reinterpret_cast<::prometheus::Summary*>( + executor_->add([this, key, value]() { + auto metricIterator = registeredMetricsMap_.find(key); + if (metricIterator == registeredMetricsMap_.end()) { + VLOG(1) << "addMetricValue for unregistered metric " << key; + return; + } + auto histogram = reinterpret_cast<::prometheus::Histogram*>( metricIterator->second.metricPtr); - summary->Observe(value); - } + histogram->Observe(value); + + std::string summaryKey = std::string(key).append(kSummarySuffix); + metricIterator = registeredMetricsMap_.find(summaryKey); + if (metricIterator != registeredMetricsMap_.end()) { + auto summary = reinterpret_cast<::prometheus::Summary*>( + metricIterator->second.metricPtr); + summary->Observe(value); + } + }); } void PrometheusStatsReporter::addHistogramMetricValue( @@ -241,4 +246,8 @@ std::string PrometheusStatsReporter::fetchMetrics() { return serializer.Serialize(impl_->registry->Collect()); } +void PrometheusStatsReporter::waitForCompletion() const { + executor_->join(); +} + } // namespace facebook::presto::prometheus diff --git a/presto-native-execution/presto_cpp/main/runtime-metrics/PrometheusStatsReporter.h b/presto-native-execution/presto_cpp/main/runtime-metrics/PrometheusStatsReporter.h index 5253a47ff5b61..2804b896165a0 100644 --- a/presto-native-execution/presto_cpp/main/runtime-metrics/PrometheusStatsReporter.h +++ b/presto-native-execution/presto_cpp/main/runtime-metrics/PrometheusStatsReporter.h @@ -16,6 +16,7 @@ #include "velox/common/base/Exceptions.h" #include "velox/common/base/GTestMacros.h" #include "velox/common/base/StatsReporter.h" +#include namespace facebook::presto::prometheus { @@ -38,8 +39,13 @@ class PrometheusStatsReporter : public facebook::velox::BaseStatsReporter { class PrometheusImpl; public: + /** + * @brief Constructor with optional thread count + * @param labels Labels for metrics. + * @param numThreads Number of threads in the executor + */ explicit PrometheusStatsReporter( - const std::map& labels); + const std::map& labels, int numThreads); void registerMetricExportType(const char* key, velox::StatType) const override; @@ -77,6 +83,12 @@ class PrometheusStatsReporter : public facebook::velox::BaseStatsReporter { std::string fetchMetrics() override; + /** + * Waits for all pending metric updates to complete. + * This is only used in tests to ensure correct timing. + */ + void waitForCompletion() const; + static std::unique_ptr createPrometheusReporter() { auto nodeConfig = NodeConfig::instance(); const std::string cluster = nodeConfig->nodeEnvironment(); @@ -84,14 +96,17 @@ class PrometheusStatsReporter : public facebook::velox::BaseStatsReporter { const std::string worker = !hostName ? "" : hostName; std::map labels{ {"cluster", cluster}, {"worker", worker}}; - return std::make_unique(labels); + return std::make_unique(labels, nodeConfig->prometheusExecutorThreads()); } + // Visible for testing + mutable std::unordered_map registeredMetricsMap_; + private: + std::shared_ptr executor_; std::shared_ptr impl_; // A map of labels assigned to each metric which helps in filtering at client // end. - mutable std::unordered_map registeredMetricsMap_; VELOX_FRIEND_TEST(PrometheusReporterTest, testCountAndGauge); VELOX_FRIEND_TEST(PrometheusReporterTest, testHistogramSummary); }; diff --git a/presto-native-execution/presto_cpp/main/runtime-metrics/tests/PrometheusReporterTest.cpp b/presto-native-execution/presto_cpp/main/runtime-metrics/tests/PrometheusReporterTest.cpp index c6b6d635d1a97..bca47b348c29d 100644 --- a/presto-native-execution/presto_cpp/main/runtime-metrics/tests/PrometheusReporterTest.cpp +++ b/presto-native-execution/presto_cpp/main/runtime-metrics/tests/PrometheusReporterTest.cpp @@ -20,7 +20,8 @@ namespace facebook::presto::prometheus { class PrometheusReporterTest : public testing::Test { public: void SetUp() override { - reporter = std::make_shared(testLabels); + reporter = std::make_shared(testLabels, 1); + multiThreadedReporter = std::make_shared(testLabels, 2); } void verifySerializedResult( @@ -40,6 +41,43 @@ class PrometheusReporterTest : public testing::Test { const std::string labelsSerialized = R"(cluster="test_cluster",worker="test_worker_pod")"; std::shared_ptr reporter; + std::shared_ptr multiThreadedReporter; +}; + +TEST_F(PrometheusReporterTest, testConcurrentReporting) { + multiThreadedReporter->registerMetricExportType( + "test.key1", facebook::velox::StatType::COUNT); + multiThreadedReporter->registerMetricExportType( + "test.key3", facebook::velox::StatType::SUM); + EXPECT_EQ( + facebook::velox::StatType::COUNT, + multiThreadedReporter->registeredMetricsMap_.find("test.key1")->second.statType); + EXPECT_EQ( + facebook::velox::StatType::SUM, + multiThreadedReporter->registeredMetricsMap_.find("test.key3")->second.statType); + + std::vector testData = {10, 12, 14}; + for (auto i : testData) { + multiThreadedReporter->addMetricValue("test.key1", i); + multiThreadedReporter->addMetricValue("test.key3", i + 2000); + } + + // Uses default value of 1 for second parameter. + multiThreadedReporter->addMetricValue("test.key1"); + multiThreadedReporter->addMetricValue("test.key3"); + + // Wait for all async updates to finish before validation + multiThreadedReporter->waitForCompletion(); + + auto fullSerializedResult = multiThreadedReporter->fetchMetrics(); + + std::vector expected = { + "# TYPE test_key1 counter", + "test_key1{" + labelsSerialized + "} 37", + "# TYPE test_key3 gauge", + "test_key3{" + labelsSerialized + "} 6037"}; + + verifySerializedResult(fullSerializedResult, expected); }; TEST_F(PrometheusReporterTest, testCountAndGauge) { @@ -75,6 +113,7 @@ TEST_F(PrometheusReporterTest, testCountAndGauge) { // Uses default value of 1 for second parameter. reporter->addMetricValue("test.key1"); reporter->addMetricValue("test.key3"); + reporter->waitForCompletion(); auto fullSerializedResult = reporter->fetchMetrics(); @@ -114,6 +153,7 @@ TEST_F(PrometheusReporterTest, testHistogramSummary) { } } reporter->addHistogramMetricValue(histogramKey, 10); + reporter->waitForCompletion(); auto fullSerializedResult = reporter->fetchMetrics(); std::replace(histSummaryKey.begin(), histSummaryKey.end(), '.', '_'); std::replace(histogramKey.begin(), histogramKey.end(), '.', '_');