Skip to content

Commit ef4288d

Browse files
feat: custom endpoint plugin
1 parent 132fd5d commit ef4288d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1873
-90
lines changed

common/lib/AllowedAndBlockedHosts.ts

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License").
5+
You may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
export class AllowedAndBlockedHosts {
18+
private readonly allowedHostIds: Set<string>;
19+
private readonly blockedHostIds: Set<string>;
20+
21+
constructor(allowedHostIds: Set<string>, blockedHostIds: Set<string>) {
22+
this.allowedHostIds = allowedHostIds;
23+
this.blockedHostIds = blockedHostIds;
24+
}
25+
26+
getAllowedHostIds() {
27+
return this.allowedHostIds;
28+
}
29+
30+
getBlockedHostIds() {
31+
return this.blockedHostIds;
32+
}
33+
}

common/lib/authentication/iam_authentication_plugin.ts

+3-4
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,23 @@
1515
*/
1616

1717
import { PluginService } from "../plugin_service";
18-
import { RdsUtils } from "../utils/rds_utils";
1918
import { Messages } from "../utils/messages";
2019
import { logger } from "../../logutils";
2120
import { AwsWrapperError } from "../utils/errors";
2221
import { HostInfo } from "../host_info";
2322
import { AwsCredentialsManager } from "./aws_credentials_manager";
2423
import { AbstractConnectionPlugin } from "../abstract_connection_plugin";
25-
import { WrapperProperties } from "../wrapper_property";
24+
import { WrapperProperties, WrapperProperty } from "../wrapper_property";
2625
import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils";
2726
import { ClientWrapper } from "../client_wrapper";
27+
import { RegionUtils } from "../utils/region_utils";
2828

2929
export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
3030
private static readonly SUBSCRIBED_METHODS = new Set<string>(["connect", "forceConnect"]);
3131
protected static readonly tokenCache = new Map<string, TokenInfo>();
3232
private readonly telemetryFactory;
3333
private readonly fetchTokenCounter;
3434
private pluginService: PluginService;
35-
rdsUtil: RdsUtils = new RdsUtils();
3635

3736
constructor(pluginService: PluginService) {
3837
super();
@@ -75,7 +74,7 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
7574
}
7675

7776
const host = IamAuthUtils.getIamHost(props, hostInfo);
78-
const region: string = IamAuthUtils.getRdsRegion(host, this.rdsUtil, props);
77+
const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);
7978
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);
8079
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
8180
if (tokenExpirationSec < 0) {

common/lib/connection_plugin_chain_builder.ts

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import { DeveloperConnectionPluginFactory } from "./plugins/dev/developer_connec
3838
import { ConnectionPluginFactory } from "./plugin_factory";
3939
import { LimitlessConnectionPluginFactory } from "./plugins/limitless/limitless_connection_plugin_factory";
4040
import { FastestResponseStrategyPluginFactory } from "./plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory";
41+
import { CustomEndpointPluginFactory } from "./plugins/custom_endpoint/custom_endpoint_plugin_factory";
4142
import { ConfigurationProfile } from "./profile/configuration_profile";
4243

4344
/*
@@ -53,6 +54,7 @@ export class ConnectionPluginChainBuilder {
5354
static readonly WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1;
5455

5556
static readonly PLUGIN_FACTORIES = new Map<string, PluginFactoryInfo>([
57+
["customEndpoint", { factory: CustomEndpointPluginFactory, weight: 380 }],
5658
["initialConnection", { factory: AuroraInitialConnectionStrategyFactory, weight: 390 }],
5759
["auroraConnectionTracker", { factory: AuroraConnectionTrackerPluginFactory, weight: 400 }],
5860
["staleDns", { factory: StaleDnsPluginFactory, weight: 500 }],

common/lib/database_dialect/database_dialect.ts

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { FailoverRestriction } from "../plugins/failover/failover_restriction";
2121
import { ErrorHandler } from "../error_handler";
2222
import { SessionState } from "../session_state";
2323
import { TransactionIsolationLevel } from "../utils/transaction_isolation_level";
24+
import { HostRole } from "../host_role";
2425

2526
export enum DatabaseType {
2627
MYSQL,
@@ -39,6 +40,7 @@ export interface DatabaseDialect {
3940
getSetSchemaQuery(schema: string): string;
4041
getDialectUpdateCandidates(): string[];
4142
getErrorHandler(): ErrorHandler;
43+
getHostRole(targetClient: ClientWrapper): Promise<HostRole>;
4244
isDialect(targetClient: ClientWrapper): Promise<boolean>;
4345
getHostListProvider(props: Map<string, any>, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider;
4446
isClientValid(targetClient: ClientWrapper): Promise<boolean>;

common/lib/host_list_provider/rds_host_list_provider.ts

+4-8
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import { WrapperProperties } from "../wrapper_property";
2828
import { logger } from "../../logutils";
2929
import { HostAvailability } from "../host_availability/host_availability";
3030
import { CacheMap } from "../utils/cache_map";
31-
import { logTopology } from "../utils/utils";
31+
import { isDialectTopologyAware, logTopology } from "../utils/utils";
3232
import { TopologyAwareDatabaseDialect } from "../topology_aware_database_dialect";
3333
import { DatabaseDialect } from "../database_dialect/database_dialect";
3434
import { ClientWrapper } from "../client_wrapper";
@@ -138,7 +138,7 @@ export class RdsHostListProvider implements DynamicHostListProvider {
138138
}
139139

140140
async getHostRole(client: ClientWrapper, dialect: DatabaseDialect): Promise<HostRole> {
141-
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
141+
if (!isDialectTopologyAware(dialect)) {
142142
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
143143
}
144144

@@ -150,7 +150,7 @@ export class RdsHostListProvider implements DynamicHostListProvider {
150150
}
151151

152152
async identifyConnection(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise<HostInfo | null> {
153-
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
153+
if (!isDialectTopologyAware(dialect)) {
154154
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
155155
}
156156
const instanceName = await dialect.identifyConnection(targetClient);
@@ -264,12 +264,8 @@ export class RdsHostListProvider implements DynamicHostListProvider {
264264
}
265265
}
266266

267-
protected isTopologyAwareDatabaseDialect(arg: any): arg is TopologyAwareDatabaseDialect {
268-
return arg;
269-
}
270-
271267
async queryForTopology(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise<HostInfo[]> {
272-
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
268+
if (!isDialectTopologyAware(dialect)) {
273269
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
274270
}
275271

common/lib/host_list_provider_service.ts

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { DatabaseDialect } from "./database_dialect/database_dialect";
2121
import { HostInfoBuilder } from "./host_info_builder";
2222
import { ConnectionUrlParser } from "./utils/connection_url_parser";
2323
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
24+
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";
2425

2526
export interface HostListProviderService {
2627
getHostListProvider(): HostListProvider | null;
@@ -50,4 +51,6 @@ export interface HostListProviderService {
5051
isClientValid(targetClient: any): Promise<boolean>;
5152

5253
getTelemetryFactory(): TelemetryFactory;
54+
55+
setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts): void;
5356
}

common/lib/plugin_service.ts

+76-6
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ import { DatabaseDialectCodes } from "./database_dialect/database_dialect_codes"
4343
import { getWriter, logTopology } from "./utils/utils";
4444
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
4545
import { DriverDialect } from "./driver_dialect/driver_dialect";
46-
import { ConfigurationProfile } from "./profile/configuration_profile";
47-
import { SessionState } from "./session_state";
46+
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";
4847

4948
export class PluginService implements ErrorHandler, HostListProviderService {
49+
private static readonly DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = 5 * 60_000_000_000; // 5 minutes
5050
private readonly _currentClient: AwsClient;
5151
private _currentHostInfo?: HostInfo;
5252
private _hostListProvider?: HostListProvider;
@@ -61,6 +61,7 @@ export class PluginService implements ErrorHandler, HostListProviderService {
6161
protected readonly sessionStateService: SessionStateService;
6262
protected static readonly hostAvailabilityExpiringCache: CacheMap<string, HostAvailability> = new CacheMap<string, HostAvailability>();
6363
readonly props: Map<string, any>;
64+
private allowedAndBlockedHosts: AllowedAndBlockedHosts;
6465

6566
constructor(
6667
container: PluginServiceManagerContainer,
@@ -116,17 +117,29 @@ export class PluginService implements ErrorHandler, HostListProviderService {
116117
this._currentHostInfo = this._initialConnectionHostInfo;
117118

118119
if (!this._currentHostInfo) {
119-
if (this.getHosts().length === 0) {
120+
if (this.getAllHosts().length === 0) {
120121
throw new AwsWrapperError(Messages.get("PluginService.hostListEmpty"));
121122
}
122123

123-
const writerHost = getWriter(this.getHosts());
124+
const writerHost = getWriter(this.getAllHosts());
125+
if (!this.hosts.includes(writerHost)) {
126+
throw new AwsWrapperError(
127+
Messages.get(
128+
"PluginServiceImpl.currentHostNotAllowed",
129+
this._currentHostInfo ? "<null>" : this._currentHostInfo.host,
130+
logTopology(this.hosts, "")
131+
)
132+
);
133+
}
134+
124135
if (writerHost) {
125136
this._currentHostInfo = writerHost;
126137
} else {
127138
this._currentHostInfo = this.getHosts()[0];
128139
}
129140
}
141+
142+
logger.debug(`Set current host to: ${this._currentHostInfo.host}`);
130143
}
131144

132145
return this._currentHostInfo;
@@ -260,11 +273,64 @@ export class PluginService implements ErrorHandler, HostListProviderService {
260273
}
261274
}
262275

263-
getHosts(): HostInfo[] {
276+
getAllHosts(): HostInfo[] {
264277
return this.hosts;
265278
}
266279

267-
setAvailability(hostAliases: Set<string>, availability: HostAvailability) {}
280+
getHosts(): HostInfo[] {
281+
const hostPermissions = this.allowedAndBlockedHosts;
282+
if (!hostPermissions) {
283+
return this.hosts;
284+
}
285+
286+
let hosts = this.hosts;
287+
const allowedHostIds = hostPermissions.getAllowedHostIds();
288+
const blockedHostIds = hostPermissions.getBlockedHostIds();
289+
290+
if (allowedHostIds && allowedHostIds.size > 0) {
291+
hosts = hosts.filter((host: HostInfo) => allowedHostIds.has(host.hostId));
292+
}
293+
294+
if (blockedHostIds && blockedHostIds.size > 0) {
295+
hosts = hosts.filter((host: HostInfo) => !blockedHostIds.has(host.hostId));
296+
}
297+
298+
return hosts;
299+
}
300+
301+
setAvailability(hostAliases: Set<string>, availability: HostAvailability) {
302+
if (hostAliases.size === 0) {
303+
return;
304+
}
305+
306+
const hostsToChange = [
307+
...new Set(
308+
this.getAllHosts().filter(
309+
(host: HostInfo) => hostAliases.has(host.asAlias) || [...host.aliases].some((hostAlias: string) => hostAliases.has(hostAlias))
310+
)
311+
)
312+
];
313+
314+
if (hostsToChange.length === 0) {
315+
logger.debug(Messages.get("PluginService.hostsChangeListEmpty"));
316+
return;
317+
}
318+
319+
const changes = new Map<string, Set<HostChangeOptions>>();
320+
for (const host of hostsToChange) {
321+
const currentAvailability = host.getAvailability();
322+
PluginService.hostAvailabilityExpiringCache.put(host.url, availability, PluginService.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO);
323+
if (currentAvailability !== availability) {
324+
let hostChanges = new Set<HostChangeOptions>();
325+
if (availability === HostAvailability.AVAILABLE) {
326+
hostChanges = new Set([HostChangeOptions.WENT_UP, HostChangeOptions.HOST_CHANGED]);
327+
} else {
328+
hostChanges = new Set([HostChangeOptions.WENT_DOWN, HostChangeOptions.HOST_CHANGED]);
329+
}
330+
changes.set(host.url, hostChanges);
331+
}
332+
}
333+
}
268334

269335
updateConfigWithProperties(props: Map<string, any>) {
270336
this._currentClient.config = Object.fromEntries(props.entries());
@@ -501,4 +567,8 @@ export class PluginService implements ErrorHandler, HostListProviderService {
501567
attachNoOpErrorListener(clientWrapper: ClientWrapper | undefined): void {
502568
this.getDialect().getErrorHandler().attachNoOpErrorListener(clientWrapper);
503569
}
570+
571+
setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts) {
572+
this.allowedAndBlockedHosts = allowedAndBlockedHosts;
573+
}
504574
}

common/lib/plugins/aurora_initial_connection_strategy_plugin.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
261261
}
262262

263263
private getWriter(): HostInfo | null {
264-
return this.pluginService.getHosts().find((x) => x.role === HostRole.WRITER) ?? null;
264+
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.WRITER) ?? null;
265265
}
266266

267267
private getReader(props: Map<string, any>): HostInfo | undefined {
@@ -278,6 +278,6 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
278278
}
279279

280280
private hasNoReaders(): boolean {
281-
return this.pluginService.getHosts().find((x) => x.role === HostRole.READER) !== undefined;
281+
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.READER) !== undefined;
282282
}
283283
}

common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
9999
}
100100

101101
private async checkWriterChanged(): Promise<void> {
102-
const hostInfoAfterFailover = this.getWriter(this.pluginService.getHosts());
102+
const hostInfoAfterFailover = this.getWriter(this.pluginService.getAllHosts());
103103
if (this.currentWriter === null) {
104104
this.currentWriter = hostInfoAfterFailover;
105105
this.needUpdateCurrentWriter = false;
@@ -114,7 +114,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
114114

115115
private rememberWriter(): void {
116116
if (this.currentWriter === null || this.needUpdateCurrentWriter) {
117-
this.currentWriter = this.getWriter(this.pluginService.getHosts());
117+
this.currentWriter = this.getWriter(this.pluginService.getAllHosts());
118118
this.needUpdateCurrentWriter = false;
119119
}
120120
}

0 commit comments

Comments
 (0)