Skip to content

Commit 3d1a9c6

Browse files
feat: custom endpoint plugin
1 parent c6478cb commit 3d1a9c6

File tree

53 files changed

+2021
-204
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+2021
-204
lines changed

common/lib/AllowedAndBlockedHosts.ts

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License").
5+
You may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
export class AllowedAndBlockedHosts {
18+
private readonly allowedHostIds: Set<string>;
19+
private readonly blockedHostIds: Set<string>;
20+
21+
constructor(allowedHostIds: Set<string>, blockedHostIds: Set<string>) {
22+
this.allowedHostIds = allowedHostIds;
23+
this.blockedHostIds = blockedHostIds;
24+
}
25+
26+
getAllowedHostIds() {
27+
return this.allowedHostIds;
28+
}
29+
30+
getBlockedHostIds() {
31+
return this.blockedHostIds;
32+
}
33+
}

common/lib/authentication/iam_authentication_plugin.ts

+3-4
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,23 @@
1515
*/
1616

1717
import { PluginService } from "../plugin_service";
18-
import { RdsUtils } from "../utils/rds_utils";
1918
import { Messages } from "../utils/messages";
2019
import { logger } from "../../logutils";
2120
import { AwsWrapperError } from "../utils/errors";
2221
import { HostInfo } from "../host_info";
2322
import { AwsCredentialsManager } from "./aws_credentials_manager";
2423
import { AbstractConnectionPlugin } from "../abstract_connection_plugin";
25-
import { WrapperProperties } from "../wrapper_property";
24+
import { WrapperProperties, WrapperProperty } from "../wrapper_property";
2625
import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils";
2726
import { ClientWrapper } from "../client_wrapper";
27+
import { RegionUtils } from "../utils/region_utils";
2828

2929
export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
3030
private static readonly SUBSCRIBED_METHODS = new Set<string>(["connect", "forceConnect"]);
3131
protected static readonly tokenCache = new Map<string, TokenInfo>();
3232
private readonly telemetryFactory;
3333
private readonly fetchTokenCounter;
3434
private pluginService: PluginService;
35-
rdsUtil: RdsUtils = new RdsUtils();
3635

3736
constructor(pluginService: PluginService) {
3837
super();
@@ -75,7 +74,7 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
7574
}
7675

7776
const host = IamAuthUtils.getIamHost(props, hostInfo);
78-
const region: string = IamAuthUtils.getRdsRegion(host, this.rdsUtil, props);
77+
const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);
7978
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);
8079
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
8180
if (tokenExpirationSec < 0) {

common/lib/connection_plugin_chain_builder.ts

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import { DeveloperConnectionPluginFactory } from "./plugins/dev/developer_connec
3939
import { ConnectionPluginFactory } from "./plugin_factory";
4040
import { LimitlessConnectionPluginFactory } from "./plugins/limitless/limitless_connection_plugin_factory";
4141
import { FastestResponseStrategyPluginFactory } from "./plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory";
42+
import { CustomEndpointPluginFactory } from "./plugins/custom_endpoint/custom_endpoint_plugin_factory";
4243
import { ConfigurationProfile } from "./profile/configuration_profile";
4344

4445
/*
@@ -54,6 +55,7 @@ export class ConnectionPluginChainBuilder {
5455
static readonly WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1;
5556

5657
static readonly PLUGIN_FACTORIES = new Map<string, PluginFactoryInfo>([
58+
["customEndpoint", { factory: CustomEndpointPluginFactory, weight: 380 }],
5759
["initialConnection", { factory: AuroraInitialConnectionStrategyFactory, weight: 390 }],
5860
["auroraConnectionTracker", { factory: AuroraConnectionTrackerPluginFactory, weight: 400 }],
5961
["staleDns", { factory: StaleDnsPluginFactory, weight: 500 }],

common/lib/database_dialect/database_dialect.ts

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { FailoverRestriction } from "../plugins/failover/failover_restriction";
2121
import { ErrorHandler } from "../error_handler";
2222
import { SessionState } from "../session_state";
2323
import { TransactionIsolationLevel } from "../utils/transaction_isolation_level";
24+
import { HostRole } from "../host_role";
2425

2526
export enum DatabaseType {
2627
MYSQL,
@@ -39,6 +40,7 @@ export interface DatabaseDialect {
3940
getSetSchemaQuery(schema: string): string;
4041
getDialectUpdateCandidates(): string[];
4142
getErrorHandler(): ErrorHandler;
43+
getHostRole(targetClient: ClientWrapper): Promise<HostRole>;
4244
isDialect(targetClient: ClientWrapper): Promise<boolean>;
4345
getHostListProvider(props: Map<string, any>, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider;
4446
isClientValid(targetClient: ClientWrapper): Promise<boolean>;

common/lib/host_list_provider/monitoring/monitoring_host_list_provider.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import { Messages } from "../../utils/messages";
2727
import { WrapperProperties } from "../../wrapper_property";
2828
import { BlockingHostListProvider } from "../host_list_provider";
2929
import { logger } from "../../../logutils";
30+
import { isDialectTopologyAware } from "../../utils/utils";
3031

3132
export class MonitoringRdsHostListProvider extends RdsHostListProvider implements BlockingHostListProvider {
3233
static readonly CACHE_CLEANUP_NANOS: bigint = BigInt(60_000_000_000); // 1 minute.
@@ -76,7 +77,7 @@ export class MonitoringRdsHostListProvider extends RdsHostListProvider implement
7677

7778
async sqlQueryForTopology(targetClient: ClientWrapper): Promise<HostInfo[]> {
7879
const dialect: DatabaseDialect = this.hostListProviderService.getDialect();
79-
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
80+
if (!isDialectTopologyAware(dialect)) {
8081
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
8182
}
8283
return await dialect.queryForTopology(targetClient, this).then((res: any) => this.processQueryResults(res));

common/lib/host_list_provider/rds_host_list_provider.ts

+5-10
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ import { WrapperProperties } from "../wrapper_property";
2727
import { logger } from "../../logutils";
2828
import { HostAvailability } from "../host_availability/host_availability";
2929
import { CacheMap } from "../utils/cache_map";
30-
import { logTopology } from "../utils/utils";
31-
import { TopologyAwareDatabaseDialect } from "../topology_aware_database_dialect";
30+
import { isDialectTopologyAware, logTopology } from "../utils/utils";
3231
import { DatabaseDialect } from "../database_dialect/database_dialect";
3332
import { ClientWrapper } from "../client_wrapper";
3433

@@ -137,7 +136,7 @@ export class RdsHostListProvider implements DynamicHostListProvider {
137136
}
138137

139138
async getHostRole(client: ClientWrapper, dialect: DatabaseDialect): Promise<HostRole> {
140-
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
139+
if (!isDialectTopologyAware(dialect)) {
141140
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
142141
}
143142

@@ -150,7 +149,7 @@ export class RdsHostListProvider implements DynamicHostListProvider {
150149

151150
async getWriterId(client: ClientWrapper): Promise<string | null> {
152151
const dialect = this.hostListProviderService.getDialect();
153-
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
152+
if (!isDialectTopologyAware(dialect)) {
154153
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
155154
}
156155

@@ -162,7 +161,7 @@ export class RdsHostListProvider implements DynamicHostListProvider {
162161
}
163162

164163
async identifyConnection(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise<HostInfo | null> {
165-
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
164+
if (!isDialectTopologyAware(dialect)) {
166165
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
167166
}
168167
const instanceName = await dialect.identifyConnection(targetClient);
@@ -276,12 +275,8 @@ export class RdsHostListProvider implements DynamicHostListProvider {
276275
}
277276
}
278277

279-
protected isTopologyAwareDatabaseDialect(arg: any): arg is TopologyAwareDatabaseDialect {
280-
return arg;
281-
}
282-
283278
async queryForTopology(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise<HostInfo[]> {
284-
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
279+
if (!isDialectTopologyAware(dialect)) {
285280
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
286281
}
287282

common/lib/host_list_provider_service.ts

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { DatabaseDialect } from "./database_dialect/database_dialect";
2121
import { HostInfoBuilder } from "./host_info_builder";
2222
import { ConnectionUrlParser } from "./utils/connection_url_parser";
2323
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
24+
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";
2425

2526
export interface HostListProviderService {
2627
getHostListProvider(): HostListProvider | null;
@@ -50,4 +51,6 @@ export interface HostListProviderService {
5051
isClientValid(targetClient: any): Promise<boolean>;
5152

5253
getTelemetryFactory(): TelemetryFactory;
54+
55+
setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts): void;
5356
}

common/lib/plugin_service.ts

+83-6
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ import { ClientWrapper } from "./client_wrapper";
4040
import { logger } from "../logutils";
4141
import { Messages } from "./utils/messages";
4242
import { DatabaseDialectCodes } from "./database_dialect/database_dialect_codes";
43-
import { getWriter } from "./utils/utils";
43+
import { getWriter, logTopology } from "./utils/utils";
4444
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
4545
import { DriverDialect } from "./driver_dialect/driver_dialect";
46+
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";
4647

4748
export class PluginService implements ErrorHandler, HostListProviderService {
49+
private static readonly DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = 5 * 60_000_000_000; // 5 minutes
4850
private readonly _currentClient: AwsClient;
4951
private _currentHostInfo?: HostInfo;
5052
private _hostListProvider?: HostListProvider;
@@ -59,6 +61,7 @@ export class PluginService implements ErrorHandler, HostListProviderService {
5961
protected readonly sessionStateService: SessionStateService;
6062
protected static readonly hostAvailabilityExpiringCache: CacheMap<string, HostAvailability> = new CacheMap<string, HostAvailability>();
6163
readonly props: Map<string, any>;
64+
private allowedAndBlockedHosts: AllowedAndBlockedHosts | null = null;
6265

6366
constructor(
6467
container: PluginServiceManagerContainer,
@@ -114,17 +117,34 @@ export class PluginService implements ErrorHandler, HostListProviderService {
114117
this._currentHostInfo = this._initialConnectionHostInfo;
115118

116119
if (!this._currentHostInfo) {
117-
if (this.getHosts().length === 0) {
120+
if (this.getAllHosts().length === 0) {
118121
throw new AwsWrapperError(Messages.get("PluginService.hostListEmpty"));
119122
}
120123

121-
const writerHost = getWriter(this.getHosts());
124+
const writerHost = getWriter(this.getAllHosts());
122125
if (writerHost) {
123126
this._currentHostInfo = writerHost;
124-
} else {
127+
if (!this.getHosts().some((hostInfo: HostInfo) => hostInfo.host === writerHost?.host)) {
128+
throw new AwsWrapperError(
129+
Messages.get(
130+
"PluginService.currentHostNotAllowed",
131+
this._currentHostInfo ? this._currentHostInfo.host : "<null>",
132+
logTopology(this.hosts, "[PluginService.currentHostNotAllowed] ")
133+
)
134+
);
135+
}
136+
}
137+
138+
if (!this._currentHostInfo) {
125139
this._currentHostInfo = this.getHosts()[0];
126140
}
127141
}
142+
143+
if (!this._currentHostInfo) {
144+
throw new AwsWrapperError(Messages.get("PluginService.currentHostNotDefined"));
145+
}
146+
147+
logger.debug(`Set current host to: ${this._currentHostInfo.host}`);
128148
}
129149

130150
return this._currentHostInfo;
@@ -286,11 +306,64 @@ export class PluginService implements ErrorHandler, HostListProviderService {
286306
}
287307
}
288308

289-
getHosts(): HostInfo[] {
309+
getAllHosts(): HostInfo[] {
290310
return this.hosts;
291311
}
292312

293-
setAvailability(hostAliases: Set<string>, availability: HostAvailability) {}
313+
getHosts(): HostInfo[] {
314+
const hostPermissions = this.allowedAndBlockedHosts;
315+
if (!hostPermissions) {
316+
return this.hosts;
317+
}
318+
319+
let hosts = this.hosts;
320+
const allowedHostIds = hostPermissions.getAllowedHostIds();
321+
const blockedHostIds = hostPermissions.getBlockedHostIds();
322+
323+
if (allowedHostIds && allowedHostIds.size > 0) {
324+
hosts = hosts.filter((host: HostInfo) => allowedHostIds.has(host.hostId));
325+
}
326+
327+
if (blockedHostIds && blockedHostIds.size > 0) {
328+
hosts = hosts.filter((host: HostInfo) => !blockedHostIds.has(host.hostId));
329+
}
330+
331+
return hosts;
332+
}
333+
334+
setAvailability(hostAliases: Set<string>, availability: HostAvailability) {
335+
if (hostAliases.size === 0) {
336+
return;
337+
}
338+
339+
const hostsToChange = [
340+
...new Set(
341+
this.getAllHosts().filter(
342+
(host: HostInfo) => hostAliases.has(host.asAlias) || [...host.aliases].some((hostAlias: string) => hostAliases.has(hostAlias))
343+
)
344+
)
345+
];
346+
347+
if (hostsToChange.length === 0) {
348+
logger.debug(Messages.get("PluginService.hostsChangeListEmpty"));
349+
return;
350+
}
351+
352+
const changes = new Map<string, Set<HostChangeOptions>>();
353+
for (const host of hostsToChange) {
354+
const currentAvailability = host.getAvailability();
355+
PluginService.hostAvailabilityExpiringCache.put(host.url, availability, PluginService.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO);
356+
if (currentAvailability !== availability) {
357+
let hostChanges = new Set<HostChangeOptions>();
358+
if (availability === HostAvailability.AVAILABLE) {
359+
hostChanges = new Set([HostChangeOptions.WENT_UP, HostChangeOptions.HOST_CHANGED]);
360+
} else {
361+
hostChanges = new Set([HostChangeOptions.WENT_DOWN, HostChangeOptions.HOST_CHANGED]);
362+
}
363+
changes.set(host.url, hostChanges);
364+
}
365+
}
366+
}
294367

295368
updateConfigWithProperties(props: Map<string, any>) {
296369
this._currentClient.config = Object.fromEntries(props.entries());
@@ -527,4 +600,8 @@ export class PluginService implements ErrorHandler, HostListProviderService {
527600
attachNoOpErrorListener(clientWrapper: ClientWrapper | undefined): void {
528601
this.getDialect().getErrorHandler().attachNoOpErrorListener(clientWrapper);
529602
}
603+
604+
setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts) {
605+
this.allowedAndBlockedHosts = allowedAndBlockedHosts;
606+
}
530607
}

common/lib/plugins/aurora_initial_connection_strategy_plugin.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
261261
}
262262

263263
private getWriter(): HostInfo | null {
264-
return this.pluginService.getHosts().find((x) => x.role === HostRole.WRITER) ?? null;
264+
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.WRITER) ?? null;
265265
}
266266

267267
private getReader(props: Map<string, any>): HostInfo | undefined {
@@ -278,6 +278,6 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
278278
}
279279

280280
private hasNoReaders(): boolean {
281-
return this.pluginService.getHosts().find((x) => x.role === HostRole.READER) !== undefined;
281+
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.READER) !== undefined;
282282
}
283283
}

common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
9999
}
100100

101101
private async checkWriterChanged(): Promise<void> {
102-
const hostInfoAfterFailover = this.getWriter(this.pluginService.getHosts());
102+
const hostInfoAfterFailover = this.getWriter(this.pluginService.getAllHosts());
103103
if (this.currentWriter === null) {
104104
this.currentWriter = hostInfoAfterFailover;
105105
this.needUpdateCurrentWriter = false;
@@ -114,7 +114,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
114114

115115
private rememberWriter(): void {
116116
if (this.currentWriter === null || this.needUpdateCurrentWriter) {
117-
this.currentWriter = this.getWriter(this.pluginService.getHosts());
117+
this.currentWriter = this.getWriter(this.pluginService.getAllHosts());
118118
this.needUpdateCurrentWriter = false;
119119
}
120120
}

0 commit comments

Comments
 (0)