Skip to content

Commit 49e54de

Browse files
committed
test: custom endpoints unit tests and integration tests
1 parent 75aacfc commit 49e54de

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+998
-281
lines changed

.github/workflows/failover.yml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
name: Failover Unit Tests
22

33
on:
4+
workflow_dispatch:
45
push:
56
branches:
67
- main
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Custom Endpoint Support
2+
3+
The Custom Endpoint support allows client application to use the driver with RDS custom endpoints. When the Custom Endpoint feature is enabled, the driver will analyse custom endpoint information to ensure instances used in connections are part of the custom endpoint being used. This includes connections used in failover.
4+
5+
## How to use the Driver with Custom Endpoint
6+
7+
### Enabling the Custom Endpoint Feature
8+
9+
1. If needed, create a custom endpoint using the AWS RDS Console:
10+
- If needed, review the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html).
11+
2. Set `ENABLE_CUSTOM_ENDPOINT_MONITORING` to `TRUE` to enable custom endpoint support.
12+
3. If you are using the failover plugin, set the failover parameter `FAILOVER_MODE` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `FAILOVER_MODE` to `strict-reader`, or if it is of type `ANY`, you can set `FAILOVER_MODE` to `reader-or-writer`.
13+
4. Specify parameters that are required or specific to your case.
14+
15+
### Custom Endpoint Plugin Parameters
16+
17+
| Parameter | Value | Required | Description | Default Value | Example Value |
18+
| ------------------------------------------ | :----: | :------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------- | ------------- |
19+
| `ENABLE_CUSTOM_ENDPOINT_MONITORING` | bool | No | Set to TRUE to enable custom endpoint support. | `FALSE` | `TRUE` |
20+
| `CUSTOM_ENDPOINT_REGION` | string | No | The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL. | `N/A` | `us-west-1` |
21+
| `CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS` | long | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `30000` | `20000` |
22+
| `CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS` | long | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` |
23+
| `WAIT_FOR_CUSTOM_ENDPOINT_INFO` | bool | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `true` | `true` |
24+
| `WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS` | long | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `5000` | `7000` |

driver/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
7373
connect.cc
7474
connection_handler.cc
7575
connection_proxy.cc
76-
custom_endpoint_proxy.cc
7776
custom_endpoint_info.cc
7877
custom_endpoint_monitor.cc
78+
custom_endpoint_proxy.cc
7979
cursor.cc
8080
desc.cc
8181
dll.cc
@@ -148,9 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
148148
cluster_topology_info.h
149149
connection_handler.h
150150
connection_proxy.h
151-
custom_endpoint_proxy.h
152151
custom_endpoint_info.h
153152
custom_endpoint_monitor.h
153+
custom_endpoint_proxy.h
154154
driver.h
155155
efm_proxy.h
156156
error.h

driver/cache_map.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
// along with this program. If not, see
2828
// http://www.gnu.org/licenses/gpl-2.0.html.
2929

30-
#ifndef __CACHE_MAP__
31-
#define __CACHE_MAP__
30+
#ifndef __CACHE_MAP_H__
31+
#define __CACHE_MAP_H__
3232

3333
#include <atomic>
3434
#include <chrono>

driver/cluster_topology_info.cc

+21
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "cluster_topology_info.h"
3131

3232
#include <stdexcept>
33+
#include <algorithm>
3334

3435
/**
3536
Initialize and return random number.
@@ -75,6 +76,19 @@ void CLUSTER_TOPOLOGY_INFO::add_host(std::shared_ptr<HOST_INFO> host_info) {
7576
update_time();
7677
}
7778

79+
void CLUSTER_TOPOLOGY_INFO::remove_host(std::shared_ptr<HOST_INFO> host_info) {
80+
auto position = std::find(writers.begin(), writers.end(), host_info);
81+
if (position != writers.end()) {
82+
writers.erase(position);
83+
}
84+
85+
position = std::find(readers.begin(), readers.end(), host_info);
86+
if (position != readers.end()) {
87+
readers.erase(position);
88+
}
89+
update_time();
90+
}
91+
7892
size_t CLUSTER_TOPOLOGY_INFO::total_hosts() {
7993
return writers.size() + readers.size();
8094
}
@@ -136,6 +150,13 @@ std::vector<std::shared_ptr<HOST_INFO>> CLUSTER_TOPOLOGY_INFO::get_writers() {
136150
return writers;
137151
}
138152

153+
std::vector<std::shared_ptr<HOST_INFO>> CLUSTER_TOPOLOGY_INFO::get_instances() {
154+
std::vector instances(writers);
155+
instances.insert(instances.end(), readers.begin(), readers.end());
156+
157+
return instances;
158+
}
159+
139160
std::shared_ptr<HOST_INFO> CLUSTER_TOPOLOGY_INFO::get_last_used_reader() {
140161
return last_used_reader;
141162
}

driver/cluster_topology_info.h

+2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class CLUSTER_TOPOLOGY_INFO {
4646
virtual ~CLUSTER_TOPOLOGY_INFO();
4747

4848
void add_host(std::shared_ptr<HOST_INFO> host_info);
49+
void remove_host(std::shared_ptr<HOST_INFO> host_info);
4950
size_t total_hosts();
5051
size_t num_readers(); // return number of readers in the cluster
5152
std::time_t time_last_updated();
@@ -58,6 +59,7 @@ class CLUSTER_TOPOLOGY_INFO {
5859
std::shared_ptr<HOST_INFO> get_reader(int i);
5960
std::vector<std::shared_ptr<HOST_INFO>> get_writers();
6061
std::vector<std::shared_ptr<HOST_INFO>> get_readers();
62+
std::vector<std::shared_ptr<HOST_INFO>> get_instances();
6163

6264
private:
6365
int current_reader = -1;

driver/connection_handler.cc

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ CONNECTION_PROXY* CONNECTION_HANDLER::connect(std::shared_ptr<HOST_INFO> host_in
8383
}
8484

8585
my_SQLFreeConnect(dbc_clone);
86+
new_connection->set_dbc(dbc);
8687

8788
return new_connection;
8889
}

driver/connection_proxy.h

+2
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ class CONNECTION_PROXY {
177177

178178
void set_custom_error_message(const char* error_message);
179179

180+
void set_dbc(DBC* dbc) { this->dbc = dbc; };
181+
180182
protected:
181183
DBC* dbc = nullptr;
182184
DataSource* ds = nullptr;

driver/custom_endpoint_info.h

+1-4
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@
3333
#include <aws/rds/model/DBClusterEndpoint.h>
3434

3535
#include <set>
36-
#include <sstream>
37-
#include <utility>
38-
39-
#include "MYODBC_MYSQL.h"
36+
#include "stringutil.h"
4037
#include "mylog.h"
4138

4239
/**

driver/custom_endpoint_monitor.cc

+79-51
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@
2727
// along with this program. If not, see
2828
// http://www.gnu.org/licenses/gpl-2.0.html.
2929

30-
#include "custom_endpoint_monitor.h"
31-
3230
#include <aws/core/auth/AWSCredentialsProviderChain.h>
3331
#include <aws/rds/model/DBClusterEndpoint.h>
3432
#include <aws/rds/model/DescribeDBClusterEndpointsRequest.h>
3533
#include <aws/rds/model/Filter.h>
34+
#include <utility>
3635
#include <vector>
3736

3837
#include "allowed_and_blocked_hosts.h"
3938
#include "aws_sdk_helper.h"
39+
#include "custom_endpoint_monitor.h"
4040
#include "driver.h"
41-
#include "monitor_service.h"
4241
#include "mylog.h"
4342

4443
namespace {
@@ -47,55 +46,81 @@ AWS_SDK_HELPER SDK_HELPER;
4746

4847
CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>> CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache;
4948

50-
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<HOST_INFO>& custom_endpoint_host_info,
49+
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
50+
const std::string& custom_endpoint_host,
5151
const std::string& endpoint_identifier, const std::string& region,
52-
DataSource* ds, bool enable_logging)
53-
: custom_endpoint_host_info(custom_endpoint_host_info),
54-
endpoint_identifier(endpoint_identifier),
55-
region(region),
56-
enable_logging(enable_logging) {
52+
long long refresh_rate_nanos, ctpl::thread_pool& thread_pool,
53+
bool enable_logging)
54+
: topology_service(topology_service),
55+
custom_endpoint_host(custom_endpoint_host),
56+
endpoint_identifier(endpoint_identifier),
57+
region(region),
58+
refresh_rate_nanos(refresh_rate_nanos),
59+
thread_pool(thread_pool),
60+
enable_logging(enable_logging) {
5761
if (enable_logging) {
5862
this->logger = init_log_file();
5963
}
6064

61-
++SDK_HELPER;
65+
this->run();
66+
}
6267

63-
Aws::RDS::RDSClientConfiguration client_config;
64-
if (!region.empty()) {
65-
client_config.region = region;
68+
#ifdef UNIT_TEST_BUILD
69+
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
70+
const std::string& custom_endpoint_host,
71+
const std::string& endpoint_identifier, const std::string& region,
72+
long long refresh_rate_nanos, ctpl::thread_pool& thread_pool,
73+
bool enable_logging, std::shared_ptr<Aws::RDS::RDSClient> client)
74+
: topology_service(topology_service),
75+
custom_endpoint_host(custom_endpoint_host),
76+
endpoint_identifier(endpoint_identifier),
77+
region(region),
78+
refresh_rate_nanos(refresh_rate_nanos),
79+
thread_pool(thread_pool),
80+
enable_logging(enable_logging) {
81+
if (enable_logging) {
82+
this->logger = init_log_file();
6683
}
67-
68-
this->rds_client = std::make_shared<Aws::RDS::RDSClient>(
69-
Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config);
70-
7184
this->run();
7285
}
86+
#endif
7387

7488
bool CUSTOM_ENDPOINT_MONITOR::should_dispose() { return true; }
7589

7690
bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const {
7791
auto default_val = std::shared_ptr<CUSTOM_ENDPOINT_INFO>(nullptr);
78-
return custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), default_val) != default_val;
92+
return custom_endpoint_cache.get(this->custom_endpoint_host, default_val) != default_val;
7993
}
8094

8195
void CUSTOM_ENDPOINT_MONITOR::run() {
82-
this->thread_pool.resize(1);
83-
this->thread_pool.push([=](int id) {
84-
MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'",
85-
this->custom_endpoint_host_info->get_host().c_str());
96+
MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());
97+
if (thread_pool.size() == 1) {
98+
// Each monitor should only have 1 thread.
99+
return;
100+
}
101+
thread_pool.resize(1);
102+
thread_pool.push([=](int id) {
103+
++SDK_HELPER;
86104

87-
try {
88-
while (!this->should_stop.load()) {
89-
const std::chrono::time_point start = std::chrono::steady_clock::now();
90-
Aws::RDS::Model::Filter filter;
91-
filter.SetName("db-cluster-endpoint-type");
92-
filter.SetValues({"custom"});
105+
Aws::RDS::RDSClientConfiguration client_config;
106+
if (!region.empty()) {
107+
client_config.region = region;
108+
}
93109

94-
Aws::RDS::Model::DescribeDBClusterEndpointsRequest request;
95-
request.SetDBClusterIdentifier(this->endpoint_identifier);
96-
request.SetFilters({filter});
97-
const auto response = this->rds_client->DescribeDBClusterEndpoints(request);
110+
const Aws::RDS::RDSClient rds_client(Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(),
111+
client_config);
112+
Aws::RDS::Model::Filter filter;
113+
filter.SetName("db-cluster-endpoint-type");
114+
filter.AddValues("custom");
98115

116+
Aws::RDS::Model::DescribeDBClusterEndpointsRequest request;
117+
request.SetDBClusterEndpointIdentifier(this->endpoint_identifier);
118+
// TODO: Investigate why filters returns `InvalidParameterCombination` error saying filter values are null.
119+
// request.AddFilters(filter);
120+
try {
121+
while (!should_stop) {
122+
const std::chrono::time_point start = std::chrono::steady_clock::now();
123+
const auto response = rds_client.DescribeDBClusterEndpoints(request);
99124
const auto custom_endpoints = response.GetResult().GetDBClusterEndpoints();
100125
if (custom_endpoints.size() != 1) {
101126
MYLOG_TRACE(this->logger, 0,
@@ -108,48 +133,52 @@ void CUSTOM_ENDPOINT_MONITOR::run() {
108133
continue;
109134
}
110135
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> endpoint_info =
111-
CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]);
136+
CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]);
112137
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> cache_endpoint_info =
113-
custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), nullptr);
138+
custom_endpoint_cache.get(this->custom_endpoint_host, nullptr);
114139

115140
if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) {
116141
const long long elapsed_time =
117-
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
142+
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
118143
std::this_thread::sleep_for(
119-
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
144+
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
120145
continue;
121146
}
122147

123148
MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}",
124-
custom_endpoint_host_info->get_host().c_str(), endpoint_info->to_string().c_str());
149+
custom_endpoint_host.c_str(), endpoint_info->to_string().c_str());
125150

126151
// The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts.
127152
std::shared_ptr<ALLOWED_AND_BLOCKED_HOSTS> allowed_and_blocked_hosts;
128153
if (endpoint_info->get_member_list_type() == STATIC_LIST) {
129154
allowed_and_blocked_hosts =
130-
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(endpoint_info->get_static_members(), std::set<std::string>());
155+
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(endpoint_info->get_static_members(), std::set<std::string>());
131156
} else {
132-
allowed_and_blocked_hosts =
133-
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(std::set<std::string>(), endpoint_info->get_excluded_members());
157+
allowed_and_blocked_hosts = std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(
158+
std::set<std::string>(), endpoint_info->get_excluded_members());
134159
}
135160

136-
custom_endpoint_cache.put(this->custom_endpoint_host_info->get_host(), endpoint_info,
137-
CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
161+
this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts);
162+
custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
138163
const long long elapsed_time =
139-
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
164+
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
140165
std::this_thread::sleep_for(
141-
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
166+
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
142167
}
143168

169+
--SDK_HELPER;
144170
} catch (const std::exception& e) {
145171
// Log and continue monitoring.
172+
--SDK_HELPER;
146173
MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what());
147174
}
175+
176+
should_stop = true;
148177
});
149178
}
150179

151180
std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string(
152-
const std::vector<Aws::RDS::Model::DBClusterEndpoint>& custom_endpoints) {
181+
const std::vector<Aws::RDS::Model::DBClusterEndpoint>& custom_endpoints) {
153182
if (custom_endpoints.empty()) {
154183
return "<no endpoints>";
155184
}
@@ -168,12 +197,11 @@ std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string(
168197
}
169198

170199
void CUSTOM_ENDPOINT_MONITOR::stop() {
171-
this->should_stop.store(true);
172-
this->thread_pool.stop(true);
173-
this->thread_pool.resize(0);
174-
custom_endpoint_cache.remove(this->custom_endpoint_host_info->get_host());
175-
--SDK_HELPER;
176-
MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host_info->get_host().c_str());
200+
should_stop = true;
201+
thread_pool.stop(true);
202+
thread_pool.resize(0);
203+
custom_endpoint_cache.remove(this->custom_endpoint_host);
204+
MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());
177205
}
178206

179207
void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); }

0 commit comments

Comments
 (0)