Skip to content

Commit 6a33b42

Browse files
committed
feat: custom endpoint support
1 parent 298a49b commit 6a33b42

20 files changed

+1085
-66
lines changed

driver/CMakeLists.txt

+9
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
6262
auth_util.cc
6363
aws_sdk_helper.cc
6464
base_metrics_holder.cc
65+
cache_map.cc
6566
catalog.cc
6667
catalog_no_i_s.cc
6768
cluster_topology_info.cc
@@ -72,6 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
7273
connect.cc
7374
connection_handler.cc
7475
connection_proxy.cc
76+
custom_endpoint_proxy.cc
77+
custom_endpoint_info.cc
78+
custom_endpoint_monitor.cc
7579
cursor.cc
7680
desc.cc
7781
dll.cc
@@ -131,9 +135,11 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
131135
CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.rc.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc @ONLY)
132136
SET(DRIVER_SRCS ${DRIVER_SRCS} driver${CONNECTOR_DRIVER_TYPE_SHORT}.def driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc
133137
adfs_proxy.h
138+
allowed_and_blocked_hosts.h
134139
auth_util.h
135140
aws_sdk_helper.h
136141
base_metrics_holder.h
142+
cache_map.h
137143
catalog.h
138144
cluster_aware_hit_metrics_holder.h
139145
cluster_aware_metrics_container.h
@@ -142,6 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
142148
cluster_topology_info.h
143149
connection_handler.h
144150
connection_proxy.h
151+
custom_endpoint_proxy.h
152+
custom_endpoint_info.h
153+
custom_endpoint_monitor.h
145154
driver.h
146155
efm_proxy.h
147156
error.h

driver/allowed_and_blocked_hosts.h

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// This program is free software; you can redistribute it and/or modify
4+
// it under the terms of the GNU General Public License, version 2.0
5+
// (GPLv2), as published by the Free Software Foundation, with the
6+
// following additional permissions:
7+
//
8+
// This program is distributed with certain software that is licensed
9+
// under separate terms, as designated in a particular file or component
10+
// or in the license documentation. Without limiting your rights under
11+
// the GPLv2, the authors of this program hereby grant you an additional
12+
// permission to link the program and your derivative works with the
13+
// separately licensed software that they have included with the program.
14+
//
15+
// Without limiting the foregoing grant of rights under the GPLv2 and
16+
// additional permission as to separately licensed software, this
17+
// program is also subject to the Universal FOSS Exception, version 1.0,
18+
// a copy of which can be found along with its FAQ at
19+
// http://oss.oracle.com/licenses/universal-foss-exception.
20+
//
21+
// This program is distributed in the hope that it will be useful, but
22+
// WITHOUT ANY WARRANTY; without even the implied warranty of
23+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
24+
// See the GNU General Public License, version 2.0, for more details.
25+
//
26+
// You should have received a copy of the GNU General Public License
27+
// along with this program. If not, see
28+
// http://www.gnu.org/licenses/gpl-2.0.html.
29+
30+
#ifndef __ALLOWED_AND_BLOCKED_HOSTS__
31+
#define __ALLOWED_AND_BLOCKED_HOSTS__
32+
33+
#include <set>
34+
#include <string>
35+
36+
/**
37+
* Represents the allowed and blocked hosts for connections.
38+
*/
39+
class ALLOWED_AND_BLOCKED_HOSTS {
40+
public:
41+
/**
42+
* Constructs an AllowedAndBlockedHosts instance with the specified allowed and blocked host IDs.
43+
* @param allowed_host_ids The set of allowed host IDs for connections. If null or empty, all host IDs that are not in
44+
* `blocked_host_ids` are allowed.
45+
* @param blocked_host_ids The set of blocked host IDs for connections. If null or empty, all host IDs in
46+
* `allowed_host_ids` are allowed. If `allowed_host_ids` is also null or empty, there
47+
* are no restrictions on which hosts are allowed.
48+
*/
49+
ALLOWED_AND_BLOCKED_HOSTS(const std::set<std::string>& allowed_host_ids,
50+
const std::set<std::string>& blocked_host_ids)
51+
: allowed_host_ids(allowed_host_ids), blocked_host_ids(blocked_host_ids){};
52+
53+
/**
54+
* Returns the set of allowed host IDs for connections. If null or empty, all host IDs that are not in
55+
* `blocked_host_ids` are allowed.
56+
*
57+
* @return the set of allowed host IDs for connections.
58+
*/
59+
std::set<std::string> get_allowed_host_ids() { return this->allowed_host_ids; };
60+
61+
/**
62+
* Returns the set of blocked host IDs for connections. If null or empty, all host IDs in `allowed_host_ids`
63+
* are allowed. If `allowed_host_ids` is also null or empty, there are no restrictions on which hosts are allowed.
64+
*
65+
* @return the set of blocked host IDs for connections.
66+
*/
67+
std::set<std::string> get_blocked_host_ids() { return this->blocked_host_ids; };
68+
69+
private:
70+
std::set<std::string> allowed_host_ids;
71+
std::set<std::string> blocked_host_ids;
72+
};
73+
74+
#endif

driver/auth_util.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ std::pair<std::string, bool> AUTH_UTIL::get_auth_token(std::unordered_map<std::s
7474
}
7575

7676
std::string auth_token;
77-
const std::string cache_key = this->build_cache_key(host, region, port, user);
77+
const std::string cache_key = build_cache_key(host, region, port, user);
7878
bool using_cached_token = false;
7979

8080
{

driver/cache_map.cc

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// This program is free software; you can redistribute it and/or modify
4+
// it under the terms of the GNU General Public License, version 2.0
5+
// (GPLv2), as published by the Free Software Foundation, with the
6+
// following additional permissions:
7+
//
8+
// This program is distributed with certain software that is licensed
9+
// under separate terms, as designated in a particular file or component
10+
// or in the license documentation. Without limiting your rights under
11+
// the GPLv2, the authors of this program hereby grant you an additional
12+
// permission to link the program and your derivative works with the
13+
// separately licensed software that they have included with the program.
14+
//
15+
// Without limiting the foregoing grant of rights under the GPLv2 and
16+
// additional permission as to separately licensed software, this
17+
// program is also subject to the Universal FOSS Exception, version 1.0,
18+
// a copy of which can be found along with its FAQ at
19+
// http://oss.oracle.com/licenses/universal-foss-exception.
20+
//
21+
// This program is distributed in the hope that it will be useful, but
22+
// WITHOUT ANY WARRANTY; without even the implied warranty of
23+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
24+
// See the GNU General Public License, version 2.0, for more details.
25+
//
26+
// You should have received a copy of the GNU General Public License
27+
// along with this program. If not, see
28+
// http://www.gnu.org/licenses/gpl-2.0.html.
29+
30+
#include "cache_map.h"
31+
32+
#include <utility>
33+
34+
#include "custom_endpoint_info.h"
35+
36+
template <class K, class V>
37+
void CACHE_MAP<K, V>::put(K key, V value, long long item_expiration_nanos) {
38+
this->cache[key] = std::make_shared<CACHE_ITEM>(
39+
value, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos));
40+
}
41+
42+
template <class K, class V>
43+
V CACHE_MAP<K, V>::get(K key, V default_value) {
44+
if (cache.count(key) > 0 && !cache[key]->is_expired()) {
45+
return this->cache[key]->item;
46+
}
47+
return default_value;
48+
}
49+
50+
template <class K, class V>
51+
V CACHE_MAP<K, V>::get(K key, V default_value, long long item_expiration_nanos) {
52+
if (cache.count(key) == 0 || this->cache[key]->is_expired()) {
53+
this->put(key, std::move(default_value), item_expiration_nanos);
54+
}
55+
return this->cache[key]->item;
56+
}
57+
58+
template <class K, class V>
59+
void CACHE_MAP<K, V>::remove(K key) {
60+
if (this->cache.count(key)) {
61+
this->cache.erase(key);
62+
}
63+
}
64+
65+
template <class K, class V>
66+
int CACHE_MAP<K, V>::size() {
67+
return this->cache.size();
68+
}
69+
70+
template <class K, class V>
71+
void CACHE_MAP<K, V>::clear() {
72+
this->cache.clear();
73+
this->clean_up();
74+
}
75+
76+
template <class K, class V>
77+
void CACHE_MAP<K, V>::clean_up() {
78+
if (std::chrono::steady_clock::now() > this->clean_up_time_nanos.load()) {
79+
this->clean_up_time_nanos =
80+
std::chrono::steady_clock::now() + std::chrono::nanoseconds(this->clean_up_time_interval_nanos);
81+
std::vector<K> keys;
82+
keys.reserve(this->cache.size());
83+
for (auto& [key, cache_item] : this->cache) {
84+
keys.push_back(key);
85+
}
86+
for (const auto& key : keys) {
87+
if (this->cache[key]->is_expired()) {
88+
this->cache.erase(key);
89+
}
90+
}
91+
}
92+
}
93+
94+
template class CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>>;

driver/cache_map.h

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// This program is free software; you can redistribute it and/or modify
4+
// it under the terms of the GNU General Public License, version 2.0
5+
// (GPLv2), as published by the Free Software Foundation, with the
6+
// following additional permissions:
7+
//
8+
// This program is distributed with certain software that is licensed
9+
// under separate terms, as designated in a particular file or component
10+
// or in the license documentation. Without limiting your rights under
11+
// the GPLv2, the authors of this program hereby grant you an additional
12+
// permission to link the program and your derivative works with the
13+
// separately licensed software that they have included with the program.
14+
//
15+
// Without limiting the foregoing grant of rights under the GPLv2 and
16+
// additional permission as to separately licensed software, this
17+
// program is also subject to the Universal FOSS Exception, version 1.0,
18+
// a copy of which can be found along with its FAQ at
19+
// http://oss.oracle.com/licenses/universal-foss-exception.
20+
//
21+
// This program is distributed in the hope that it will be useful, but
22+
// WITHOUT ANY WARRANTY; without even the implied warranty of
23+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
24+
// See the GNU General Public License, version 2.0, for more details.
25+
//
26+
// You should have received a copy of the GNU General Public License
27+
// along with this program. If not, see
28+
// http://www.gnu.org/licenses/gpl-2.0.html.
29+
30+
#ifndef __CACHE_MAP__
31+
#define __CACHE_MAP__
32+
33+
#include <chrono>
34+
#include <memory>
35+
#include <unordered_map>
36+
37+
template <class K, class V>
38+
class CACHE_MAP {
39+
public:
40+
class CACHE_ITEM {
41+
public:
42+
CACHE_ITEM() = default;
43+
CACHE_ITEM(V item, std::chrono::steady_clock::time_point expiration_time)
44+
: item(item), expiration_time(expiration_time){};
45+
~CACHE_ITEM() = default;
46+
V item;
47+
48+
bool is_expired() { return std::chrono::steady_clock::now() > this->expiration_time; }
49+
50+
private:
51+
std::chrono::steady_clock::time_point expiration_time;
52+
};
53+
54+
CACHE_MAP() = default;
55+
~CACHE_MAP() = default;
56+
57+
void put(K key, V value, long long item_expiration_nanos);
58+
V get(K key, V default_value);
59+
V get(K key, V default_value, long long item_expiration_nanos);
60+
void remove(K key);
61+
int size();
62+
void clear();
63+
64+
protected:
65+
void clean_up();
66+
const long long clean_up_time_interval_nanos = 60000000000; // 10 minute
67+
std::atomic<std::chrono::steady_clock::time_point> clean_up_time_nanos;
68+
69+
private:
70+
std::unordered_map<K, std::shared_ptr<CACHE_ITEM>> cache;
71+
};
72+
73+
#endif

driver/custom_endpoint_info.cc

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
//
3+
// This program is free software {} you can redistribute it and/or modify
4+
// it under the terms of the GNU General Public License, version 2.0
5+
// (GPLv2), as published by the Free Software Foundation, with the
6+
// following additional permissions:
7+
//
8+
// This program is distributed with certain software that is licensed
9+
// under separate terms, as designated in a particular file or component
10+
// or in the license documentation. Without limiting your rights under
11+
// the GPLv2, the authors of this program hereby grant you an additional
12+
// permission to link the program and your derivative works with the
13+
// separately licensed software that they have included with the program.
14+
//
15+
// Without limiting the foregoing grant of rights under the GPLv2 and
16+
// additional permission as to separately licensed software, this
17+
// program is also subject to the Universal FOSS Exception, version 1.0,
18+
// a copy of which can be found along with its FAQ at
19+
// http://oss.oracle.com/licenses/universal-foss-exception.
20+
//
21+
// This program is distributed in the hope that it will be useful, but
22+
// WITHOUT ANY WARRANTY {} without even the implied warranty of
23+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
24+
// See the GNU General Public License, version 2.0, for more details.
25+
//
26+
// You should have received a copy of the GNU General Public License
27+
// along with this program. If not, see
28+
// http://www.gnu.org/licenses/gpl-2.0.html.
29+
30+
#include "custom_endpoint_info.h"
31+
32+
std::shared_ptr<CUSTOM_ENDPOINT_INFO> CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(
33+
const Aws::RDS::Model::DBClusterEndpoint& response_endpoint_info) {
34+
std::vector<std::string> members;
35+
MEMBERS_LIST_TYPE members_list_type;
36+
37+
if (response_endpoint_info.StaticMembersHasBeenSet()) {
38+
members = response_endpoint_info.GetStaticMembers();
39+
members_list_type = STATIC_LIST;
40+
} else {
41+
members = response_endpoint_info.GetExcludedMembers();
42+
members_list_type = EXCLUSION_LIST;
43+
}
44+
45+
std::set members_set(members.begin(), members.end());
46+
47+
return std::make_shared<CUSTOM_ENDPOINT_INFO>(
48+
response_endpoint_info.GetDBClusterEndpointIdentifier(), response_endpoint_info.GetDBClusterIdentifier(),
49+
response_endpoint_info.GetEndpoint(),
50+
CUSTOM_ENDPOINT_INFO::get_role_type(response_endpoint_info.GetCustomEndpointType()), members_set,
51+
members_list_type);
52+
}
53+
54+
std::set<std::string> CUSTOM_ENDPOINT_INFO::get_excluded_members() const {
55+
if (this->member_list_type == EXCLUSION_LIST) {
56+
return members;
57+
}
58+
59+
return std::set<std::string>();
60+
}
61+
62+
std::set<std::string> CUSTOM_ENDPOINT_INFO::get_static_members() const {
63+
if (this->member_list_type == STATIC_LIST) {
64+
return members;
65+
}
66+
67+
return std::set<std::string>();
68+
}
69+
70+
bool operator==(const CUSTOM_ENDPOINT_INFO& current, const CUSTOM_ENDPOINT_INFO& other) {
71+
return current.endpoint_identifier == other.endpoint_identifier &&
72+
current.cluster_identifier == other.cluster_identifier && current.url == other.url &&
73+
current.role_type == other.role_type &&
74+
current.member_list_type == other.member_list_type;
75+
}
76+
77+
CUSTOM_ENDPOINT_ROLE_TYPE CUSTOM_ENDPOINT_INFO::get_role_type(const Aws::String& role_type) {
78+
auto it = CUSTOM_ENDPOINT_ROLE_TYPE_MAP.find(role_type);
79+
if (it != CUSTOM_ENDPOINT_ROLE_TYPE_MAP.end()) {
80+
return it->second;
81+
}
82+
83+
throw std::invalid_argument("Invalid role type for custom endpoint, this should not have happened.");
84+
}

0 commit comments

Comments
 (0)