Skip to content

Commit 6698fb5

Browse files
feat: custom endpoint plugin
1 parent f0dfc1c commit 6698fb5

39 files changed

+1582
-64
lines changed

common/lib/AllowedAndBlockedHosts.ts

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License").
5+
You may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
export class AllowedAndBlockedHosts {
18+
private readonly allowedHostIds: Set<string>;
19+
private readonly blockedHostIds: Set<string>;
20+
21+
constructor(allowedHostIds: Set<string>, blockedHostIds: Set<string>) {
22+
this.allowedHostIds = allowedHostIds;
23+
this.blockedHostIds = blockedHostIds;
24+
}
25+
26+
getAllowedHostIds() {
27+
return this.allowedHostIds;
28+
}
29+
30+
getBlockedHostIds() {
31+
return this.blockedHostIds;
32+
}
33+
}

common/lib/authentication/iam_authentication_plugin.ts

+3-4
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,23 @@
1515
*/
1616

1717
import { PluginService } from "../plugin_service";
18-
import { RdsUtils } from "../utils/rds_utils";
1918
import { Messages } from "../utils/messages";
2019
import { logger } from "../../logutils";
2120
import { AwsWrapperError } from "../utils/errors";
2221
import { HostInfo } from "../host_info";
2322
import { AwsCredentialsManager } from "./aws_credentials_manager";
2423
import { AbstractConnectionPlugin } from "../abstract_connection_plugin";
25-
import { WrapperProperties } from "../wrapper_property";
24+
import { WrapperProperties, WrapperProperty } from "../wrapper_property";
2625
import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils";
2726
import { ClientWrapper } from "../client_wrapper";
27+
import { RegionUtils } from "../utils/region_utils";
2828

2929
export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
3030
private static readonly SUBSCRIBED_METHODS = new Set<string>(["connect", "forceConnect"]);
3131
protected static readonly tokenCache = new Map<string, TokenInfo>();
3232
private readonly telemetryFactory;
3333
private readonly fetchTokenCounter;
3434
private pluginService: PluginService;
35-
rdsUtil: RdsUtils = new RdsUtils();
3635

3736
constructor(pluginService: PluginService) {
3837
super();
@@ -75,7 +74,7 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
7574
}
7675

7776
const host = IamAuthUtils.getIamHost(props, hostInfo);
78-
const region: string = IamAuthUtils.getRdsRegion(host, this.rdsUtil, props);
77+
const region: string = RegionUtils.getRegion(props, WrapperProperties.IAM_REGION.name, host);
7978
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);
8079
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
8180
if (tokenExpirationSec < 0) {

common/lib/connection_plugin_chain_builder.ts

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import { DeveloperConnectionPluginFactory } from "./plugins/dev/developer_connec
3838
import { ConnectionPluginFactory } from "./plugin_factory";
3939
import { LimitlessConnectionPluginFactory } from "./plugins/limitless/limitless_connection_plugin_factory";
4040
import { FastestResponseStrategyPluginFactory } from "./plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory";
41+
import { CustomEndpointPluginFactory } from "./plugins/custom_endpoint/custom_endpoint_plugin_factory";
4142
import { ConfigurationProfile } from "./profile/configuration_profile";
4243

4344
/*
@@ -53,6 +54,7 @@ export class ConnectionPluginChainBuilder {
5354
static readonly WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1;
5455

5556
static readonly PLUGIN_FACTORIES = new Map<string, PluginFactoryInfo>([
57+
["customEndpoint", { factory: CustomEndpointPluginFactory, weight: 380 }],
5658
["initialConnection", { factory: AuroraInitialConnectionStrategyFactory, weight: 390 }],
5759
["auroraConnectionTracker", { factory: AuroraConnectionTrackerPluginFactory, weight: 400 }],
5860
["staleDns", { factory: StaleDnsPluginFactory, weight: 500 }],

common/lib/host_list_provider_service.ts

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { DatabaseDialect } from "./database_dialect/database_dialect";
2121
import { HostInfoBuilder } from "./host_info_builder";
2222
import { ConnectionUrlParser } from "./utils/connection_url_parser";
2323
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
24+
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";
2425

2526
export interface HostListProviderService {
2627
getHostListProvider(): HostListProvider | null;
@@ -50,4 +51,6 @@ export interface HostListProviderService {
5051
isClientValid(targetClient: any): Promise<boolean>;
5152

5253
getTelemetryFactory(): TelemetryFactory;
54+
55+
setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts): void;
5356
}

common/lib/plugin_service.ts

+70-6
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ import { DatabaseDialectCodes } from "./database_dialect/database_dialect_codes"
4343
import { getWriter, logTopology } from "./utils/utils";
4444
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
4545
import { DriverDialect } from "./driver_dialect/driver_dialect";
46-
import { ConfigurationProfile } from "./profile/configuration_profile";
47-
import { SessionState } from "./session_state";
46+
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";
4847

4948
export class PluginService implements ErrorHandler, HostListProviderService {
49+
private static readonly DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = 5 * 60_000_000_000; // 5 minutes
5050
private readonly _currentClient: AwsClient;
5151
private _currentHostInfo?: HostInfo;
5252
private _hostListProvider?: HostListProvider;
@@ -61,6 +61,7 @@ export class PluginService implements ErrorHandler, HostListProviderService {
6161
protected readonly sessionStateService: SessionStateService;
6262
protected static readonly hostAvailabilityExpiringCache: CacheMap<string, HostAvailability> = new CacheMap<string, HostAvailability>();
6363
readonly props: Map<string, any>;
64+
private allowedAndBlockedHosts: AllowedAndBlockedHosts;
6465

6566
constructor(
6667
container: PluginServiceManagerContainer,
@@ -116,17 +117,29 @@ export class PluginService implements ErrorHandler, HostListProviderService {
116117
this._currentHostInfo = this._initialConnectionHostInfo;
117118

118119
if (!this._currentHostInfo) {
119-
if (this.getHosts().length === 0) {
120+
if (this.getAllHosts().length === 0) {
120121
throw new AwsWrapperError(Messages.get("PluginService.hostListEmpty"));
121122
}
122123

123-
const writerHost = getWriter(this.getHosts());
124+
const writerHost = getWriter(this.getAllHosts());
125+
if (!this.hosts.includes(writerHost)) {
126+
throw new AwsWrapperError(
127+
Messages.get(
128+
"PluginServiceImpl.currentHostNotAllowed",
129+
this._currentHostInfo ? "<null>" : this._currentHostInfo.host,
130+
logTopology(this.hosts, "")
131+
)
132+
);
133+
}
134+
124135
if (writerHost) {
125136
this._currentHostInfo = writerHost;
126137
} else {
127138
this._currentHostInfo = this.getHosts()[0];
128139
}
129140
}
141+
142+
logger.debug(`Set current host to: ${this._currentHostInfo.host}`);
130143
}
131144

132145
return this._currentHostInfo;
@@ -260,11 +273,58 @@ export class PluginService implements ErrorHandler, HostListProviderService {
260273
}
261274
}
262275

263-
getHosts(): HostInfo[] {
276+
getAllHosts(): HostInfo[] {
264277
return this.hosts;
265278
}
266279

267-
setAvailability(hostAliases: Set<string>, availability: HostAvailability) {}
280+
getHosts(): HostInfo[] {
281+
const hostPermissions = this.allowedAndBlockedHosts;
282+
if (!hostPermissions) {
283+
return this.hosts;
284+
}
285+
286+
let hosts = this.hosts;
287+
const allowedHostIds = hostPermissions.getAllowedHostIds();
288+
const blockedHostIds = hostPermissions.getBlockedHostIds();
289+
290+
if (allowedHostIds && allowedHostIds.size > 0) {
291+
hosts = hosts.filter((host: HostInfo) => allowedHostIds.has(host.hostId));
292+
}
293+
294+
if (blockedHostIds && blockedHostIds.size > 0) {
295+
hosts = hosts.filter((host: HostInfo) => blockedHostIds.has(host.hostId));
296+
}
297+
298+
return hosts;
299+
}
300+
301+
setAvailability(hostAliases: Set<string>, availability: HostAvailability) {
302+
if (hostAliases.size === 0) {
303+
return;
304+
}
305+
306+
const hostsToChange = [...new Set(this.getAllHosts().filter((host: HostInfo) => hostAliases.has(host.asAlias) || host.aliases))];
307+
308+
if (hostsToChange.length === 0) {
309+
logger.debug(Messages.get("PluginServiceImpl.hostsChangelistEmpty"));
310+
return;
311+
}
312+
313+
const changes = new Map<string, Set<HostChangeOptions>>();
314+
for (const host of hostsToChange) {
315+
const currentAvailability = host.getAvailability();
316+
PluginService.hostAvailabilityExpiringCache.put(host.url, availability, PluginService.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO);
317+
if (currentAvailability !== availability) {
318+
let hostChanges = new Set<HostChangeOptions>();
319+
if (availability === HostAvailability.AVAILABLE) {
320+
hostChanges = new Set([HostChangeOptions.WENT_UP, HostChangeOptions.HOST_CHANGED]);
321+
} else {
322+
hostChanges = new Set([HostChangeOptions.WENT_DOWN, HostChangeOptions.HOST_CHANGED]);
323+
}
324+
changes.set(host.url, hostChanges);
325+
}
326+
}
327+
}
268328

269329
updateConfigWithProperties(props: Map<string, any>) {
270330
this._currentClient.config = Object.fromEntries(props.entries());
@@ -501,4 +561,8 @@ export class PluginService implements ErrorHandler, HostListProviderService {
501561
attachNoOpErrorListener(clientWrapper: ClientWrapper | undefined): void {
502562
this.getDialect().getErrorHandler().attachNoOpErrorListener(clientWrapper);
503563
}
564+
565+
setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts) {
566+
this.allowedAndBlockedHosts = allowedAndBlockedHosts;
567+
}
504568
}

common/lib/plugins/aurora_initial_connection_strategy_plugin.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
261261
}
262262

263263
private getWriter(): HostInfo | null {
264-
return this.pluginService.getHosts().find((x) => x.role === HostRole.WRITER) ?? null;
264+
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.WRITER) ?? null;
265265
}
266266

267267
private getReader(props: Map<string, any>): HostInfo | undefined {
@@ -278,6 +278,6 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
278278
}
279279

280280
private hasNoReaders(): boolean {
281-
return this.pluginService.getHosts().find((x) => x.role === HostRole.READER) !== undefined;
281+
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.READER) !== undefined;
282282
}
283283
}

common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
9999
}
100100

101101
private async checkWriterChanged(): Promise<void> {
102-
const hostInfoAfterFailover = this.getWriter(this.pluginService.getHosts());
102+
const hostInfoAfterFailover = this.getWriter(this.pluginService.getAllHosts());
103103
if (this.currentWriter === null) {
104104
this.currentWriter = hostInfoAfterFailover;
105105
this.needUpdateCurrentWriter = false;
@@ -114,7 +114,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
114114

115115
private rememberWriter(): void {
116116
if (this.currentWriter === null || this.needUpdateCurrentWriter) {
117-
this.currentWriter = this.getWriter(this.pluginService.getHosts());
117+
this.currentWriter = this.getWriter(this.pluginService.getAllHosts());
118118
this.needUpdateCurrentWriter = false;
119119
}
120120
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License").
5+
You may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
import { CustomEndpointRoleType, customEndpointRoleTypeFromValue } from "./custom_endpoint_role_type";
18+
import { MemberListType } from "./member_list_type";
19+
20+
export class CustomEndpointInfo {
21+
private endpointIdentifier: string; // ID portion of the custom endpoint URL.
22+
private readonly clusterIdentifier: string; // ID of the cluster that the custom endpoint belongs to.
23+
private readonly url: string;
24+
private readonly roleType: CustomEndpointRoleType;
25+
26+
// A given custom endpoint will either specify a static list or an exclusion list, as indicated by `memberListType`.
27+
// If the list is a static list, 'members' specifies instances included in the custom endpoint, and new cluster
28+
// instances will not be automatically added to the custom endpoint. If it is an exclusion list, 'members' specifies
29+
// instances excluded by the custom endpoint, and new cluster instances will be added to the custom endpoint.
30+
private readonly memberListType: MemberListType;
31+
private readonly members: Set<string>;
32+
33+
constructor(
34+
endpointIdentifier: string,
35+
clusterIdentifier: string,
36+
url: string,
37+
roleType: CustomEndpointRoleType,
38+
members: Set<string>,
39+
memberListType: MemberListType
40+
) {
41+
this.endpointIdentifier = endpointIdentifier;
42+
this.clusterIdentifier = clusterIdentifier;
43+
this.url = url;
44+
this.roleType = roleType;
45+
this.members = members;
46+
this.memberListType = memberListType;
47+
}
48+
49+
getMemberListType(): MemberListType {
50+
return this.memberListType;
51+
}
52+
53+
static fromDbClusterEndpoint(responseEndpointInfo: any): CustomEndpointInfo {
54+
let members: Set<string>;
55+
let memberListType: MemberListType;
56+
57+
if (responseEndpointInfo.StaticMembers) {
58+
members = responseEndpointInfo.StaticMembers;
59+
memberListType = MemberListType.STATIC_LIST;
60+
} else {
61+
members = responseEndpointInfo.ExcludedMembers;
62+
memberListType = MemberListType.EXCLUSION_LIST;
63+
}
64+
65+
return new CustomEndpointInfo(
66+
responseEndpointInfo.DBClusterEndpointIdentifier,
67+
responseEndpointInfo.DBClusterIdentifier,
68+
responseEndpointInfo.Endpoint,
69+
customEndpointRoleTypeFromValue(responseEndpointInfo.CustomEndpointType),
70+
members,
71+
memberListType
72+
);
73+
}
74+
75+
getStaticMembers(): Set<string> {
76+
return this.memberListType === MemberListType.STATIC_LIST ? this.members : null;
77+
}
78+
79+
getExcludedMembers(): Set<string> {
80+
return this.memberListType === MemberListType.EXCLUSION_LIST ? this.members : null;
81+
}
82+
83+
equals(obj: any): boolean {
84+
if (!obj) {
85+
return false;
86+
}
87+
88+
if (obj === this) {
89+
return true;
90+
}
91+
92+
const info = obj as CustomEndpointInfo;
93+
return (
94+
this.endpointIdentifier === info.endpointIdentifier &&
95+
this.clusterIdentifier === info.clusterIdentifier &&
96+
this.url === info.url &&
97+
this.roleType === info.roleType &&
98+
this.members === info.members &&
99+
this.memberListType === info.memberListType
100+
);
101+
}
102+
103+
toString(): string {
104+
return `CustomEndpointInfo[url=${this.url}, clusterIdentifier=${this.clusterIdentifier}, customEndpointType=${this.roleType}, memberListType=${this.memberListType}, members=${this.members}]`;
105+
}
106+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License").
5+
You may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
export interface CustomEndpointMonitor {
18+
shouldDispose(): boolean;
19+
hasCustomEndpointInfo(): boolean;
20+
close(): Promise<void>;
21+
}

0 commit comments

Comments
 (0)