diff --git a/pgcat.toml b/pgcat.toml
index 87f2700c..71ef33d3 100644
--- a/pgcat.toml
+++ b/pgcat.toml
@@ -1,350 +1,33 @@
-#
-# PgCat config example.
-#
+# This is an example of the most basic config
+# that will mimic what PgBouncer does in transaction mode with one server.
 
-#
-# General pooler settings
 [general]
-# What IP to run on, 0.0.0.0 means accessible from everywhere.
-host = "0.0.0.0"
-
-# Port to run on, same as PgBouncer used in this example.
-port = 6432
-
-# Whether to enable prometheus exporter or not.
-enable_prometheus_exporter = true
-
-# Port on which the prometheus exporter listens.
-prometheus_exporter_port = 9930
-
-# How long to wait before aborting a server connection (ms).
-connect_timeout = 5000 # milliseconds
-
-# How long an idle connection with a server is left open (ms).
-idle_timeout = 30000 # milliseconds
-
-# Max connection lifetime before it's closed, even if actively used.
-server_lifetime = 86400000 # 24 hours
-
-# How long a client is allowed to be idle while in a transaction (ms).
-idle_client_in_transaction_timeout = 0 # milliseconds
-
-# How much time to give the health check query to return with a result (ms).
-healthcheck_timeout = 1000 # milliseconds
-
-# How long to keep connection available for immediate re-use, without running a healthcheck query on it
-healthcheck_delay = 30000 # milliseconds
-
-# How much time to give clients during shutdown before forcibly killing client connections (ms).
-shutdown_timeout = 60000 # milliseconds
-
-# How long to ban a server if it fails a health check (seconds).
-ban_time = 60 # seconds
-
-# If we should log client connections
-log_client_connections = false
-
-# If we should log client disconnections
-log_client_disconnections = false
-
-# How often (ms) to check the config file for changes; PgCat reloads it if it changed.
-autoreload = 15000
-
-# Number of worker threads the Runtime will use (4 by default).
-worker_threads = 5
-
-# Number of seconds of connection idleness to wait before sending a keepalive packet to the server.
-tcp_keepalives_idle = 5
-# Number of unacknowledged keepalive packets allowed before giving up and closing the connection.
-tcp_keepalives_count = 5
-# Number of seconds between keepalive packets.
-tcp_keepalives_interval = 5
-
-# Path to TLS Certificate file to use for TLS connections
-# tls_certificate = ".circleci/server.cert"
-# Path to TLS private key file to use for TLS connections
-# tls_private_key = ".circleci/server.key"
-
-# Enable/disable server TLS
-server_tls = false
-
-# Verify server certificate is completely authentic.
-verify_server_certificate = false
-
-# User name to access the virtual administrative database (pgbouncer or pgcat)
-# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.
-admin_username = "admin_user"
-# Password to access the virtual administrative database
-admin_password = "admin_pass"
-
-# Default plugins that are configured on all pools.
-[plugins]
-
-# Prewarmer plugin that runs queries on server startup, before giving the connection
-# to the client.
-[plugins.prewarmer]
-enabled = false
-queries = [
-  "SELECT pg_prewarm('pgbench_accounts')",
-]
-
-# Log all queries to stdout.
-[plugins.query_logger]
-enabled = false
-
-# Block access to tables that Postgres does not allow us to control.
-[plugins.table_access]
-enabled = false
-tables = [
-  "pg_user",
-  "pg_roles",
-  "pg_database",
-]
-
-# Intercept user queries and give a fake reply.
-[plugins.intercept]
-enabled = true
-
-[plugins.intercept.queries.0]
-
-query = "select current_database() as a, current_schemas(false) as b"
-schema = [
-  ["a", "text"],
-  ["b", "text"],
-]
-result = [
-  ["${DATABASE}", "{public}"],
-]
-
-[plugins.intercept.queries.1]
-
-query = "select current_database(), current_schema(), current_user"
-schema = [
-  ["current_database", "text"],
-  ["current_schema", "text"],
-  ["current_user", "text"],
-]
-result = [
-  ["${DATABASE}", "public", "${USER}"],
-]
-
-
-# Pool configs are structured as pools.<pool_name>.
-# The pool_name is what clients use as the database name when connecting.
-# For a pool named `sharded_db`, clients access that pool using a connection string like
-# `postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db`
-[pools.sharded_db]
-# Pool mode (see PgBouncer docs for more).
-# `session` one server connection per connected client
-# `transaction` one server connection per client transaction
-pool_mode = "transaction"
-
-# Load balancing mode
-# `random` selects the server at random
-# `loc` selects the server with the least outstanding busy connections
-load_balancing_mode = "random"
-
-# If the client doesn't specify, PgCat routes traffic to this role by default.
-# `any` round-robin between primary and replicas,
-# `replica` round-robin between replicas only without touching the primary,
-# `primary` all queries go to the primary unless otherwise specified.
-default_role = "any"
-
-# Prepared statements cache size.
-# TODO: update documentation
-prepared_statements_cache_size = 500
-
-# If Query Parser is enabled, we'll attempt to parse
-# every incoming query to determine if it's a read or a write.
-# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
-# we'll direct it to the primary.
-query_parser_enabled = true
+log_level = "DEBUG"
 
-# If the query parser is enabled and this setting is enabled, we'll attempt to
-# infer the role from the query itself.
-query_parser_read_write_splitting = true
-
-# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
-# load balancing of read queries. Otherwise, the primary will only be used for write
-# queries. The primary can always be explicitly selected with our custom protocol.
-primary_reads_enabled = true
-
-# Allow sharding commands to be passed as statement comments instead of
-# separate commands. If these are unset this functionality is disabled.
-# sharding_key_regex = '/\* sharding_key: (\d+) \*/'
-# shard_id_regex = '/\* shard_id: (\d+) \*/'
-# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
-
-# Defines the behavior when no shard is selected in a sharded system.
-# `random`: picks a shard at random
-# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
-# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard, every time
-# default_shard = "shard_0"
-
-# So what if you wanted to implement a different hashing function,
-# or you've already built one and you want this pooler to use it?
-# Current options:
-# `pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
-# `sha1`: A hashing function based on SHA1
-sharding_function = "pg_bigint_hash"
-
-# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
-# established using the database configured in the pool. This parameter is inherited by every pool
-# and can be redefined in pool configuration.
-# auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
-
-# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
-# specified in `auth_query`. The connection will be established using the database configured in the pool.
-# This parameter is inherited by every pool and can be redefined in pool configuration.
-# auth_query_user = "sharding_user"
-
-# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
-# specified in `auth_query`. The connection will be established using the database configured in the pool.
-# This parameter is inherited by every pool and can be redefined in pool configuration.
-# auth_query_password = "sharding_user"
-
-# Automatically parse this from queries and route queries to the right shard!
-# automatic_sharding_key = "data.id"
-
-# Idle timeout can be overwritten in the pool
-idle_timeout = 40000
-
-# Connect timeout can be overwritten in the pool
-connect_timeout = 3000
-
-# When enabled, ip resolutions for server connections specified using hostnames will be cached
-# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
-# old ip connections are closed (gracefully) and new connections will start using new ip.
-# dns_cache_enabled = false
-
-# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
-# dns_max_ttl = 30
-
-# Plugins can be configured on a pool-per-pool basis. This overrides the global plugins setting,
-# so all plugins have to be configured here again.
-[pools.sharded_db.plugins]
-
-[pools.sharded_db.plugins.prewarmer]
-enabled = true
-queries = [
-  "SELECT pg_prewarm('pgbench_accounts')",
-]
-
-[pools.sharded_db.plugins.query_logger]
-enabled = false
-
-[pools.sharded_db.plugins.table_access]
-enabled = false
-tables = [
-  "pg_user",
-  "pg_roles",
-  "pg_database",
-]
-
-[pools.sharded_db.plugins.intercept]
-enabled = true
-
-[pools.sharded_db.plugins.intercept.queries.0]
-
-query = "select current_database() as a, current_schemas(false) as b"
-schema = [
-  ["a", "text"],
-  ["b", "text"],
-]
-result = [
-  ["${DATABASE}", "{public}"],
-]
-
-[pools.sharded_db.plugins.intercept.queries.1]
-
-query = "select current_database(), current_schema(), current_user"
-schema = [
-  ["current_database", "text"],
-  ["current_schema", "text"],
-  ["current_user", "text"],
-]
-result = [
-  ["${DATABASE}", "public", "${USER}"],
-]
-
-# User configs are structured as pools.<pool_name>.users.<user_index>.
-# This section holds the credentials for users that may connect to this cluster
-[pools.sharded_db.users.0]
-# PostgreSQL username used to authenticate the user and connect to the server
-# if `server_username` is not set.
-username = "sharding_user"
-
-# PostgreSQL password used to authenticate the user and connect to the server
-# if `server_password` is not set.
-password = "sharding_user"
-
-pool_mode = "transaction"
-
-# PostgreSQL username used to connect to the server.
-# server_username = "another_user"
-
-# PostgreSQL password used to connect to the server.
-# server_password = "another_password"
-
-# Maximum number of server connections that can be established for this user.
-# The maximum number of connections from a single PgCat process to any database in the cluster
-# is the sum of pool_size across all users.
-pool_size = 9
-
-
-# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
-# 0 means it is disabled.
-statement_timeout = 0
-
-[pools.sharded_db.users.1]
-username = "other_user"
-password = "other_user"
-pool_size = 21
-statement_timeout = 15000
-connect_timeout = 1000
-idle_timeout = 1000
-
-# Shard configs are structured as pools.<pool_name>.shards.<shard_index>.
-# Each shard config contains a list of servers that make up the shard
-# and the database name to use.
-[pools.sharded_db.shards.0]
-# Array of servers in the shard, each server entry is an array of `[host, port, role]`
-servers = [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
-
-# Array of mirrors for the shard, each mirror entry is an array of `[host, port, index of server in servers array]`
-# Traffic hitting the server identified by the index will be sent to the mirror.
-# mirrors = [["1.2.3.4", 5432, 0], ["1.2.3.4", 5432, 1]]
-
-# Database name (e.g. "postgres")
-database = "shard0"
+host = "0.0.0.0"
+port = 6433
+admin_username = "pgcat"
+admin_password = "pgcat"
 
-[pools.sharded_db.shards.1]
-servers = [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
-database = "shard1"
+[pools.pgml]
+auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
 
-[pools.sharded_db.shards.2]
-servers = [["127.0.0.1", 5432, "primary" ], ["localhost", 5432, "replica" ]]
-database = "shard2"
+# Be sure to grant this user LOGIN on Postgres
+auth_query_user = "postgres"
+auth_query_password = "postgres"
+proxy = true
 
-[pools.simple_db]
+[pools.pgml.users.0]
+username = "postgres"
+#password = "postgres"
+pool_size = 10
+min_pool_size = 1
 pool_mode = "session"
-default_role = "primary"
-query_parser_enabled = true
-primary_reads_enabled = true
-sharding_function = "pg_bigint_hash"
-
-[pools.simple_db.users.0]
-username = "simple_user"
-password = "simple_user"
-pool_size = 5
-min_pool_size = 3
-server_lifetime = 60000
-statement_timeout = 0
 
-[pools.simple_db.shards.0]
+[pools.pgml.shards.0]
 servers = [
-    [ "127.0.0.1", 5432, "primary" ],
-    [ "localhost", 5432, "replica" ]
+    ["localhost", 5432, "primary"]
 ]
-database = "some_db"
+database = "postgres"
diff --git a/pgcat.toml.original b/pgcat.toml.original
new file mode 100644
index 00000000..87f2700c
--- /dev/null
+++ b/pgcat.toml.original
@@ -0,0 +1,350 @@
+#
+# PgCat config example.
+#
+
+#
+# General pooler settings
+[general]
+# What IP to run on, 0.0.0.0 means accessible from everywhere.
+host = "0.0.0.0"
+
+# Port to run on, same as PgBouncer used in this example.
+port = 6432
+
+# Whether to enable prometheus exporter or not.
+enable_prometheus_exporter = true
+
+# Port on which the prometheus exporter listens.
+prometheus_exporter_port = 9930
+
+# How long to wait before aborting a server connection (ms).
+connect_timeout = 5000 # milliseconds
+
+# How long an idle connection with a server is left open (ms).
+idle_timeout = 30000 # milliseconds
+
+# Max connection lifetime before it's closed, even if actively used.
+server_lifetime = 86400000 # 24 hours
+
+# How long a client is allowed to be idle while in a transaction (ms).
+idle_client_in_transaction_timeout = 0 # milliseconds
+
+# How much time to give the health check query to return with a result (ms).
+healthcheck_timeout = 1000 # milliseconds
+
+# How long to keep connection available for immediate re-use, without running a healthcheck query on it
+healthcheck_delay = 30000 # milliseconds
+
+# How much time to give clients during shutdown before forcibly killing client connections (ms).
+shutdown_timeout = 60000 # milliseconds
+
+# How long to ban a server if it fails a health check (seconds).
+ban_time = 60 # seconds
+
+# If we should log client connections
+log_client_connections = false
+
+# If we should log client disconnections
+log_client_disconnections = false
+
+# How often (ms) to check the config file for changes; PgCat reloads it if it changed.
+autoreload = 15000
+
+# Number of worker threads the Runtime will use (4 by default).
+worker_threads = 5
+
+# Number of seconds of connection idleness to wait before sending a keepalive packet to the server.
+tcp_keepalives_idle = 5
+# Number of unacknowledged keepalive packets allowed before giving up and closing the connection.
+tcp_keepalives_count = 5
+# Number of seconds between keepalive packets.
+tcp_keepalives_interval = 5
+
+# Path to TLS Certificate file to use for TLS connections
+# tls_certificate = ".circleci/server.cert"
+# Path to TLS private key file to use for TLS connections
+# tls_private_key = ".circleci/server.key"
+
+# Enable/disable server TLS
+server_tls = false
+
+# Verify server certificate is completely authentic.
+verify_server_certificate = false
+
+# User name to access the virtual administrative database (pgbouncer or pgcat)
+# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.
+admin_username = "admin_user"
+# Password to access the virtual administrative database
+admin_password = "admin_pass"
+
+# Default plugins that are configured on all pools.
+[plugins]
+
+# Prewarmer plugin that runs queries on server startup, before giving the connection
+# to the client.
+[plugins.prewarmer]
+enabled = false
+queries = [
+  "SELECT pg_prewarm('pgbench_accounts')",
+]
+
+# Log all queries to stdout.
+[plugins.query_logger]
+enabled = false
+
+# Block access to tables that Postgres does not allow us to control.
+[plugins.table_access]
+enabled = false
+tables = [
+  "pg_user",
+  "pg_roles",
+  "pg_database",
+]
+
+# Intercept user queries and give a fake reply.
+[plugins.intercept]
+enabled = true
+
+[plugins.intercept.queries.0]
+
+query = "select current_database() as a, current_schemas(false) as b"
+schema = [
+  ["a", "text"],
+  ["b", "text"],
+]
+result = [
+  ["${DATABASE}", "{public}"],
+]
+
+[plugins.intercept.queries.1]
+
+query = "select current_database(), current_schema(), current_user"
+schema = [
+  ["current_database", "text"],
+  ["current_schema", "text"],
+  ["current_user", "text"],
+]
+result = [
+  ["${DATABASE}", "public", "${USER}"],
+]
+
+
+# Pool configs are structured as pools.<pool_name>.
+# The pool_name is what clients use as the database name when connecting.
+# For a pool named `sharded_db`, clients access that pool using a connection string like
+# `postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db`
+[pools.sharded_db]
+# Pool mode (see PgBouncer docs for more).
+# `session` one server connection per connected client
+# `transaction` one server connection per client transaction
+pool_mode = "transaction"
+
+# Load balancing mode
+# `random` selects the server at random
+# `loc` selects the server with the least outstanding busy connections
+load_balancing_mode = "random"
+
+# If the client doesn't specify, PgCat routes traffic to this role by default.
+# `any` round-robin between primary and replicas,
+# `replica` round-robin between replicas only without touching the primary,
+# `primary` all queries go to the primary unless otherwise specified.
+default_role = "any"
+
+# Prepared statements cache size.
+# TODO: update documentation
+prepared_statements_cache_size = 500
+
+# If Query Parser is enabled, we'll attempt to parse
+# every incoming query to determine if it's a read or a write.
+# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
+# we'll direct it to the primary.
+query_parser_enabled = true
+
+# If the query parser is enabled and this setting is enabled, we'll attempt to
+# infer the role from the query itself.
+query_parser_read_write_splitting = true
+
+# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
+# load balancing of read queries. Otherwise, the primary will only be used for write
+# queries. The primary can always be explicitly selected with our custom protocol.
+primary_reads_enabled = true
+
+# Allow sharding commands to be passed as statement comments instead of
+# separate commands. If these are unset this functionality is disabled.
+# sharding_key_regex = '/\* sharding_key: (\d+) \*/'
+# shard_id_regex = '/\* shard_id: (\d+) \*/'
+# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
+
+# Defines the behavior when no shard is selected in a sharded system.
+# `random`: picks a shard at random
+# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
+# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard, every time
+# default_shard = "shard_0"
+
+# So what if you wanted to implement a different hashing function,
+# or you've already built one and you want this pooler to use it?
+# Current options:
+# `pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
+# `sha1`: A hashing function based on SHA1
+sharding_function = "pg_bigint_hash"
+
+# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
+# established using the database configured in the pool. This parameter is inherited by every pool
+# and can be redefined in pool configuration.
+# auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
+
+# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
+# specified in `auth_query`. The connection will be established using the database configured in the pool.
+# This parameter is inherited by every pool and can be redefined in pool configuration.
+# auth_query_user = "sharding_user"
+
+# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
+# specified in `auth_query`. The connection will be established using the database configured in the pool.
+# This parameter is inherited by every pool and can be redefined in pool configuration.
+# auth_query_password = "sharding_user"
+
+# Automatically parse this from queries and route queries to the right shard!
+# automatic_sharding_key = "data.id"
+
+# Idle timeout can be overwritten in the pool
+idle_timeout = 40000
+
+# Connect timeout can be overwritten in the pool
+connect_timeout = 3000
+
+# When enabled, ip resolutions for server connections specified using hostnames will be cached
+# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
+# old ip connections are closed (gracefully) and new connections will start using new ip.
+# dns_cache_enabled = false
+
+# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
+# dns_max_ttl = 30
+
+# Plugins can be configured on a pool-per-pool basis. This overrides the global plugins setting,
+# so all plugins have to be configured here again.
+[pools.sharded_db.plugins]
+
+[pools.sharded_db.plugins.prewarmer]
+enabled = true
+queries = [
+  "SELECT pg_prewarm('pgbench_accounts')",
+]
+
+[pools.sharded_db.plugins.query_logger]
+enabled = false
+
+[pools.sharded_db.plugins.table_access]
+enabled = false
+tables = [
+  "pg_user",
+  "pg_roles",
+  "pg_database",
+]
+
+[pools.sharded_db.plugins.intercept]
+enabled = true
+
+[pools.sharded_db.plugins.intercept.queries.0]
+
+query = "select current_database() as a, current_schemas(false) as b"
+schema = [
+  ["a", "text"],
+  ["b", "text"],
+]
+result = [
+  ["${DATABASE}", "{public}"],
+]
+
+[pools.sharded_db.plugins.intercept.queries.1]
+
+query = "select current_database(), current_schema(), current_user"
+schema = [
+  ["current_database", "text"],
+  ["current_schema", "text"],
+  ["current_user", "text"],
+]
+result = [
+  ["${DATABASE}", "public", "${USER}"],
+]
+
+# User configs are structured as pools.<pool_name>.users.<user_index>.
+# This section holds the credentials for users that may connect to this cluster
+[pools.sharded_db.users.0]
+# PostgreSQL username used to authenticate the user and connect to the server
+# if `server_username` is not set.
+username = "sharding_user"
+
+# PostgreSQL password used to authenticate the user and connect to the server
+# if `server_password` is not set.
+password = "sharding_user"
+
+pool_mode = "transaction"
+
+# PostgreSQL username used to connect to the server.
+# server_username = "another_user"
+
+# PostgreSQL password used to connect to the server.
+# server_password = "another_password"
+
+# Maximum number of server connections that can be established for this user.
+# The maximum number of connections from a single PgCat process to any database in the cluster
+# is the sum of pool_size across all users.
+pool_size = 9
+
+
+# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
+# 0 means it is disabled.
+statement_timeout = 0
+
+[pools.sharded_db.users.1]
+username = "other_user"
+password = "other_user"
+pool_size = 21
+statement_timeout = 15000
+connect_timeout = 1000
+idle_timeout = 1000
+
+# Shard configs are structured as pools.<pool_name>.shards.<shard_index>.
+# Each shard config contains a list of servers that make up the shard
+# and the database name to use.
+[pools.sharded_db.shards.0]
+# Array of servers in the shard, each server entry is an array of `[host, port, role]`
+servers = [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
+
+# Array of mirrors for the shard, each mirror entry is an array of `[host, port, index of server in servers array]`
+# Traffic hitting the server identified by the index will be sent to the mirror.
+# mirrors = [["1.2.3.4", 5432, 0], ["1.2.3.4", 5432, 1]]
+
+# Database name (e.g. "postgres")
"postgres") +database = "shard0" + +[pools.sharded_db.shards.1] +servers = [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]] +database = "shard1" + +[pools.sharded_db.shards.2] +servers = [["127.0.0.1", 5432, "primary" ], ["localhost", 5432, "replica" ]] +database = "shard2" + + +[pools.simple_db] +pool_mode = "session" +default_role = "primary" +query_parser_enabled = true +primary_reads_enabled = true +sharding_function = "pg_bigint_hash" + +[pools.simple_db.users.0] +username = "simple_user" +password = "simple_user" +pool_size = 5 +min_pool_size = 3 +server_lifetime = 60000 +statement_timeout = 0 + +[pools.simple_db.shards.0] +servers = [ + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ] +] +database = "some_db" diff --git a/src/client.rs b/src/client.rs index c72e9d2a..15295411 100644 --- a/src/client.rs +++ b/src/client.rs @@ -20,7 +20,7 @@ use crate::config::{ use crate::constants::*; use crate::messages::*; use crate::plugins::PluginOutput; -use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; +use crate::pool::{get_or_create_pool, get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; use crate::server::{Server, ServerParameters}; use crate::stats::{ClientStats, ServerStats}; @@ -557,7 +557,7 @@ where } // Authenticate normal user. else { - let pool = match get_pool(pool_name, username) { + let pool = match get_or_create_pool(pool_name, username).await { Some(pool) => pool, None => { error_response( diff --git a/src/config.rs b/src/config.rs index e56f92b9..c91cae48 100644 --- a/src/config.rs +++ b/src/config.rs @@ -597,6 +597,8 @@ pub struct Pool { #[serde(default = "Pool::default_prepared_statements_cache_size")] pub prepared_statements_cache_size: usize, + #[serde(default = "Pool::default_proxy")] + pub proxy: bool, // Support for query routing based on database activity #[serde(default = "Pool::default_db_activity_based_routing")] pub db_activity_based_routing: bool, @@ -663,6 +665,10 @@ impl Pool { 0 } + pub fn default_proxy() -> bool { + false + } + pub fn default_db_activity_based_routing() -> bool { false } @@ -811,6 +817,7 @@ impl Default for Pool { cleanup_server_connections: true, log_client_parameter_status_changes: false, prepared_statements_cache_size: Self::default_prepared_statements_cache_size(), + proxy: Self::default_proxy(), db_activity_based_routing: Self::default_db_activity_based_routing(), db_activity_init_delay: Self::default_db_activity_init_delay(), db_activity_ttl: Self::default_db_activity_ttl(), @@ -1290,6 +1297,11 @@ impl Config { pool_name, pool_config.pool_mode.to_string() ); + info!( + "[pool: {}] Proxy mode: {}", + pool_name, + pool_config.proxy.to_string() + ); info!( "[pool: {}] Load Balancing mode: {:?}", pool_name, pool_config.load_balancing_mode diff --git a/src/pool.rs b/src/pool.rs index 7ecf24c1..a7d484d2 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -226,6 +226,9 @@ pub struct PoolSettings { pub auth_query_user: Option, pub auth_query_password: Option, + // Proxy + pub proxy: bool, + /// Plugins pub plugins: Option, } @@ -260,6 +263,7 @@ impl Default for PoolSettings { auth_query: None, auth_query_user: None, auth_query_password: None, + proxy: false, plugins: None, } } @@ -318,300 +322,309 @@ impl ConnectionPool { for (pool_name, pool_config) in &config.pools { let new_pool_hash_value = pool_config.hash_value(); - // There is one pool per database/user pair. 
-            for user in pool_config.users.values() {
-                let old_pool_ref = get_pool(pool_name, &user.username);
-                let identifier = PoolIdentifier::new(pool_name, &user.username);
-
-                if let Some(pool) = old_pool_ref {
-                    // If the pool hasn't changed, get existing reference and insert it into the new_pools.
-                    // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
-                    if pool.config_hash == new_pool_hash_value {
-                        info!(
-                            "[pool: {}][user: {}] has not changed",
-                            pool_name, user.username
-                        );
-                        new_pools.insert(identifier.clone(), pool.clone());
-                        continue;
-                    }
-                }
+            let is_proxy: bool = pool_config.proxy;
+            if !is_proxy {
+                // There is one pool per database/user pair.
+                for user in pool_config.users.values() {
+                    let old_pool_ref = get_pool(pool_name, &user.username);
+                    let identifier = PoolIdentifier::new(pool_name, &user.username);
+
+                    if let Some(pool) = old_pool_ref {
+                        // If the pool hasn't changed, get existing reference and insert it into the new_pools.
+                        // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
+                        if pool.config_hash == new_pool_hash_value {
+                            info!(
+                                "[pool: {}][user: {}] has not changed",
+                                pool_name, user.username
+                            );
+                            new_pools.insert(identifier.clone(), pool.clone());
+                            continue;
+                        }
+                    }
 
-                info!(
-                    "[pool: {}][user: {}] creating new pool",
-                    pool_name, user.username
-                );
+                    info!(
+                        "[pool: {}][user: {}] creating new pool",
+                        pool_name, user.username
+                    );
 
-                let mut shards = Vec::new();
-                let mut addresses = Vec::new();
-                let mut banlist = Vec::new();
-                let mut shard_ids = pool_config
-                    .shards
-                    .clone()
-                    .into_keys()
-                    .collect::<Vec<String>>();
-
-                // Sort by shard number to ensure consistency.
-                shard_ids.sort_by_key(|k| k.parse::<usize>().unwrap());
-                let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
-
-                for shard_idx in &shard_ids {
-                    let shard = &pool_config.shards[shard_idx];
-                    let mut pools = Vec::new();
-                    let mut servers = Vec::new();
-                    let mut replica_number = 0;
-
-                    // Load Mirror settings
-                    for (address_index, server) in shard.servers.iter().enumerate() {
-                        let mut mirror_addresses = vec![];
-                        if let Some(mirror_settings_vec) = &shard.mirrors {
-                            for (mirror_idx, mirror_settings) in
-                                mirror_settings_vec.iter().enumerate()
-                            {
-                                if mirror_settings.mirroring_target_index != address_index {
-                                    continue;
-                                }
-                                mirror_addresses.push(Address {
-                                    id: address_id,
-                                    database: shard.database.clone(),
-                                    host: mirror_settings.host.clone(),
-                                    port: mirror_settings.port,
-                                    role: server.role,
-                                    address_index: mirror_idx,
-                                    replica_number,
-                                    shard: shard_idx.parse::<usize>().unwrap(),
-                                    username: user.username.clone(),
-                                    pool_name: pool_name.clone(),
-                                    mirrors: vec![],
-                                    stats: Arc::new(AddressStats::default()),
-                                    error_count: Arc::new(AtomicU64::new(0)),
-                                });
-                                address_id += 1;
-                            }
-                        }
+                    let mut shards = Vec::new();
+                    let mut addresses = Vec::new();
+                    let mut banlist = Vec::new();
+                    let mut shard_ids = pool_config
+                        .shards
+                        .clone()
+                        .into_keys()
+                        .collect::<Vec<String>>();
+
+                    // Sort by shard number to ensure consistency.
+                    shard_ids.sort_by_key(|k| k.parse::<usize>().unwrap());
+                    let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
+
+                    for shard_idx in &shard_ids {
+                        let shard = &pool_config.shards[shard_idx];
+                        let mut pools = Vec::new();
+                        let mut servers = Vec::new();
+                        let mut replica_number = 0;
+
+                        // Load Mirror settings
+                        for (address_index, server) in shard.servers.iter().enumerate() {
+                            let mut mirror_addresses = vec![];
+                            if let Some(mirror_settings_vec) = &shard.mirrors {
+                                for (mirror_idx, mirror_settings) in
+                                    mirror_settings_vec.iter().enumerate()
+                                {
+                                    if mirror_settings.mirroring_target_index != address_index {
+                                        continue;
+                                    }
+                                    mirror_addresses.push(Address {
+                                        id: address_id,
+                                        database: shard.database.clone(),
+                                        host: mirror_settings.host.clone(),
+                                        port: mirror_settings.port,
+                                        role: server.role,
+                                        address_index: mirror_idx,
+                                        replica_number,
+                                        shard: shard_idx.parse::<usize>().unwrap(),
+                                        username: user.username.clone(),
+                                        pool_name: pool_name.clone(),
+                                        mirrors: vec![],
+                                        stats: Arc::new(AddressStats::default()),
+                                        error_count: Arc::new(AtomicU64::new(0)),
+                                    });
+                                    address_id += 1;
+                                }
+                            }
 
-                        let address = Address {
-                            id: address_id,
-                            database: shard.database.clone(),
-                            host: server.host.clone(),
-                            port: server.port,
-                            role: server.role,
-                            address_index,
-                            replica_number,
-                            shard: shard_idx.parse::<usize>().unwrap(),
-                            username: user.username.clone(),
-                            pool_name: pool_name.clone(),
-                            mirrors: mirror_addresses,
-                            stats: Arc::new(AddressStats::default()),
-                            error_count: Arc::new(AtomicU64::new(0)),
-                        };
-                        address_id += 1;
-
-                        if server.role == Role::Replica {
-                            replica_number += 1;
-                        }
+                            let address = Address {
+                                id: address_id,
+                                database: shard.database.clone(),
+                                host: server.host.clone(),
+                                port: server.port,
+                                role: server.role,
+                                address_index,
+                                replica_number,
+                                shard: shard_idx.parse::<usize>().unwrap(),
+                                username: user.username.clone(),
+                                pool_name: pool_name.clone(),
+                                mirrors: mirror_addresses,
+                                stats: Arc::new(AddressStats::default()),
+                                error_count: Arc::new(AtomicU64::new(0)),
+                            };
+
+                            address_id += 1;
+
+                            if server.role == Role::Replica {
+                                replica_number += 1;
+                            }
 
-                        // We assume every server in the pool shares user/passwords
-                        let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
-
-                        if let Some(apt) = &auth_passthrough {
-                            match apt.fetch_hash(&address).await {
-                                Ok(ok) => {
-                                    if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
-                                    {
-                                        if ok != *pool_auth_hash_value {
-                                            warn!(
-                                                "Hash is not the same across shards \
-                                                of the same pool, client auth will \
-                                                be done using last obtained hash. \
-                                                Server: {}:{}, Database: {}",
-                                                server.host, server.port, shard.database,
-                                            );
-                                        }
-                                    }
-                                    debug!("Hash obtained for {:?}", address);
-
-                                    {
-                                        let mut pool_auth_hash = pool_auth_hash.write();
-                                        *pool_auth_hash = Some(ok.clone());
-                                    }
-                                }
-                                Err(err) => warn!(
-                                    "Could not obtain password hashes \
-                                    using auth_query config, ignoring. \
-                                    Error: {:?}",
-                                    err,
-                                ),
-                            }
-                        }
+                            // We assume every server in the pool shares user/passwords
+                            let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
+
+                            if let Some(apt) = &auth_passthrough {
+                                match apt.fetch_hash(&address).await {
+                                    Ok(ok) => {
+                                        if let Some(ref pool_auth_hash_value) =
+                                            *(pool_auth_hash.read())
+                                        {
+                                            if ok != *pool_auth_hash_value {
+                                                warn!(
+                                                    "Hash is not the same across shards \
+                                                    of the same pool, client auth will \
+                                                    be done using last obtained hash. \
+                                                    Server: {}:{}, Database: {}",
+                                                    server.host, server.port, shard.database,
+                                                );
+                                            }
+                                        }
+
+                                        debug!("Hash obtained for {:?}", address);
+
+                                        {
+                                            let mut pool_auth_hash = pool_auth_hash.write();
+                                            *pool_auth_hash = Some(ok.clone());
+                                        }
+                                    }
+                                    Err(err) => warn!(
+                                        "Could not obtain password hashes \
+                                        using auth_query config, ignoring. \
+                                        Error: {:?}",
+                                        err,
+                                    ),
+                                }
+                            }
 
-                        let manager = ServerPool::new(
-                            address.clone(),
-                            user.clone(),
-                            &shard.database,
-                            client_server_map.clone(),
-                            pool_auth_hash.clone(),
-                            match pool_config.plugins {
-                                Some(ref plugins) => Some(plugins.clone()),
-                                None => config.plugins.clone(),
-                            },
-                            pool_config.cleanup_server_connections,
-                            pool_config.log_client_parameter_status_changes,
-                            pool_config.prepared_statements_cache_size,
-                        );
+                            let manager = ServerPool::new(
+                                address.clone(),
+                                user.clone(),
+                                &shard.database,
+                                client_server_map.clone(),
+                                pool_auth_hash.clone(),
+                                match pool_config.plugins {
+                                    Some(ref plugins) => Some(plugins.clone()),
+                                    None => config.plugins.clone(),
+                                },
+                                pool_config.cleanup_server_connections,
+                                pool_config.log_client_parameter_status_changes,
+                                pool_config.prepared_statements_cache_size,
+                            );
 
-                        let connect_timeout = match user.connect_timeout {
-                            Some(connect_timeout) => connect_timeout,
-                            None => match pool_config.connect_timeout {
-                                Some(connect_timeout) => connect_timeout,
-                                None => config.general.connect_timeout,
-                            },
-                        };
-
-                        let idle_timeout = match user.idle_timeout {
-                            Some(idle_timeout) => idle_timeout,
-                            None => match pool_config.idle_timeout {
-                                Some(idle_timeout) => idle_timeout,
-                                None => config.general.idle_timeout,
-                            },
-                        };
-
-                        let server_lifetime = match user.server_lifetime {
-                            Some(server_lifetime) => server_lifetime,
-                            None => match pool_config.server_lifetime {
-                                Some(server_lifetime) => server_lifetime,
-                                None => config.general.server_lifetime,
-                            },
-                        };
-
-                        let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
-                            .iter()
-                            .min()
-                            .unwrap();
-
-                        let queue_strategy = match config.general.server_round_robin {
-                            true => QueueStrategy::Fifo,
-                            false => QueueStrategy::Lifo,
-                        };
-
-                        debug!(
-                            "[pool: {}][user: {}] Pool reaper rate: {}ms",
-                            pool_name, user.username, reaper_rate
-                        );
-
-                        let pool = Pool::builder()
-                            .max_size(user.pool_size)
-                            .min_idle(user.min_pool_size)
-                            .connection_timeout(std::time::Duration::from_millis(connect_timeout))
-                            .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
-                            .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
-                            .reaper_rate(std::time::Duration::from_millis(reaper_rate))
-                            .queue_strategy(queue_strategy)
-                            .test_on_check_out(false);
-
-                        let pool = if config.general.validate_config {
-                            pool.build(manager).await?
-                        } else {
-                            pool.build_unchecked(manager)
-                        };
-
-                        pools.push(pool);
-                        servers.push(address);
-                    }
-
-                    shards.push(pools);
-                    addresses.push(servers);
-                    banlist.push(HashMap::new());
-                }
-
-                assert_eq!(shards.len(), addresses.len());
-                if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
-                    info!(
-                        "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
-                        pool_name, user.username
-                    );
-                }
+                            let connect_timeout = match user.connect_timeout {
+                                Some(connect_timeout) => connect_timeout,
+                                None => match pool_config.connect_timeout {
+                                    Some(connect_timeout) => connect_timeout,
+                                    None => config.general.connect_timeout,
+                                },
+                            };
+
+                            let idle_timeout = match user.idle_timeout {
+                                Some(idle_timeout) => idle_timeout,
+                                None => match pool_config.idle_timeout {
+                                    Some(idle_timeout) => idle_timeout,
+                                    None => config.general.idle_timeout,
+                                },
+                            };
+
+                            let server_lifetime = match user.server_lifetime {
+                                Some(server_lifetime) => server_lifetime,
+                                None => match pool_config.server_lifetime {
+                                    Some(server_lifetime) => server_lifetime,
+                                    None => config.general.server_lifetime,
+                                },
+                            };
+
+                            let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
+                                .iter()
+                                .min()
+                                .unwrap();
+
+                            let queue_strategy = match config.general.server_round_robin {
+                                true => QueueStrategy::Fifo,
+                                false => QueueStrategy::Lifo,
+                            };
+
+                            debug!(
+                                "[pool: {}][user: {}] Pool reaper rate: {}ms",
+                                pool_name, user.username, reaper_rate
+                            );
+
+                            let pool = Pool::builder()
+                                .max_size(user.pool_size)
+                                .min_idle(user.min_pool_size)
+                                .connection_timeout(std::time::Duration::from_millis(
+                                    connect_timeout,
+                                ))
+                                .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
+                                .max_lifetime(Some(std::time::Duration::from_millis(
+                                    server_lifetime,
+                                )))
+                                .reaper_rate(std::time::Duration::from_millis(reaper_rate))
+                                .queue_strategy(queue_strategy)
+                                .test_on_check_out(false);
+
+                            let pool = if config.general.validate_config {
+                                pool.build(manager).await?
+                            } else {
+                                pool.build_unchecked(manager)
+                            };
+
+                            pools.push(pool);
+                            servers.push(address);
+                        }
+
+                        shards.push(pools);
+                        addresses.push(servers);
+                        banlist.push(HashMap::new());
+                    }
+
+                    assert_eq!(shards.len(), addresses.len());
+                    if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
+                        info!(
+                            "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
+                            pool_name, user.username
+                        );
+                    }
 
-                let pool = ConnectionPool {
-                    databases: Arc::new(shards),
-                    addresses: Arc::new(addresses),
-                    banlist: Arc::new(RwLock::new(banlist)),
-                    config_hash: new_pool_hash_value,
-                    original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
-                    auth_hash: pool_auth_hash,
-                    settings: Arc::new(PoolSettings {
-                        pool_mode: match user.pool_mode {
-                            Some(pool_mode) => pool_mode,
-                            None => pool_config.pool_mode,
-                        },
-                        load_balancing_mode: pool_config.load_balancing_mode,
-                        checkout_failure_limit: pool_config.checkout_failure_limit,
-                        // shards: pool_config.shards.clone(),
-                        shards: shard_ids.len(),
-                        user: user.clone(),
-                        db: pool_name.clone(),
-                        default_role: match pool_config.default_role.as_str() {
-                            "any" => None,
-                            "replica" => Some(Role::Replica),
-                            "primary" => Some(Role::Primary),
-                            _ => unreachable!(),
-                        },
-                        query_parser_enabled: pool_config.query_parser_enabled,
-                        query_parser_max_length: pool_config.query_parser_max_length,
-                        query_parser_read_write_splitting: pool_config
-                            .query_parser_read_write_splitting,
-                        primary_reads_enabled: pool_config.primary_reads_enabled,
-                        sharding_function: pool_config.sharding_function,
-                        db_activity_based_routing: pool_config.db_activity_based_routing,
-                        db_activity_init_delay: pool_config.db_activity_init_delay,
-                        db_activity_ttl: pool_config.db_activity_ttl,
-                        table_mutation_cache_ms_ttl: pool_config.table_mutation_cache_ms_ttl,
-                        automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
-                        healthcheck_delay: config.general.healthcheck_delay,
-                        healthcheck_timeout: config.general.healthcheck_timeout,
-                        ban_time: config.general.ban_time,
-                        sharding_key_regex: pool_config
-                            .sharding_key_regex
-                            .clone()
-                            .map(|regex| Regex::new(regex.as_str()).unwrap()),
-                        shard_id_regex: pool_config
-                            .shard_id_regex
-                            .clone()
-                            .map(|regex| Regex::new(regex.as_str()).unwrap()),
-                        regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
-                        default_shard: pool_config.default_shard,
-                        auth_query: pool_config.auth_query.clone(),
-                        auth_query_user: pool_config.auth_query_user.clone(),
-                        auth_query_password: pool_config.auth_query_password.clone(),
-                        plugins: match pool_config.plugins {
-                            Some(ref plugins) => Some(plugins.clone()),
-                            None => config.plugins.clone(),
-                        },
-                    }),
-                    validated: Arc::new(AtomicBool::new(false)),
-                    paused: Arc::new(AtomicBool::new(false)),
-                    paused_waiter: Arc::new(Notify::new()),
-                    prepared_statement_cache: match pool_config.prepared_statements_cache_size {
-                        0 => None,
-                        _ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
-                            pool_config.prepared_statements_cache_size,
-                        )))),
-                    },
-                };
+                    let pool = ConnectionPool {
+                        databases: Arc::new(shards),
+                        addresses: Arc::new(addresses),
+                        banlist: Arc::new(RwLock::new(banlist)),
+                        config_hash: new_pool_hash_value,
+                        original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
+                        auth_hash: pool_auth_hash,
+                        settings: Arc::new(PoolSettings {
+                            pool_mode: match user.pool_mode {
+                                Some(pool_mode) => pool_mode,
+                                None => pool_config.pool_mode,
+                            },
+                            load_balancing_mode: pool_config.load_balancing_mode,
+                            checkout_failure_limit: pool_config.checkout_failure_limit,
+                            // shards: pool_config.shards.clone(),
+                            shards: shard_ids.len(),
+                            user: user.clone(),
+                            db: pool_name.clone(),
+                            default_role: match pool_config.default_role.as_str() {
+                                "any" => None,
+                                "replica" => Some(Role::Replica),
+                                "primary" => Some(Role::Primary),
+                                _ => unreachable!(),
+                            },
+                            query_parser_enabled: pool_config.query_parser_enabled,
+                            query_parser_max_length: pool_config.query_parser_max_length,
+                            query_parser_read_write_splitting: pool_config
+                                .query_parser_read_write_splitting,
+                            primary_reads_enabled: pool_config.primary_reads_enabled,
+                            sharding_function: pool_config.sharding_function,
+                            db_activity_based_routing: pool_config.db_activity_based_routing,
+                            db_activity_init_delay: pool_config.db_activity_init_delay,
+                            db_activity_ttl: pool_config.db_activity_ttl,
+                            table_mutation_cache_ms_ttl: pool_config.table_mutation_cache_ms_ttl,
+                            automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
+                            healthcheck_delay: config.general.healthcheck_delay,
+                            healthcheck_timeout: config.general.healthcheck_timeout,
+                            ban_time: config.general.ban_time,
+                            sharding_key_regex: pool_config
+                                .sharding_key_regex
+                                .clone()
+                                .map(|regex| Regex::new(regex.as_str()).unwrap()),
+                            shard_id_regex: pool_config
+                                .shard_id_regex
+                                .clone()
+                                .map(|regex| Regex::new(regex.as_str()).unwrap()),
+                            regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
+                            default_shard: pool_config.default_shard,
+                            auth_query: pool_config.auth_query.clone(),
+                            auth_query_user: pool_config.auth_query_user.clone(),
+                            auth_query_password: pool_config.auth_query_password.clone(),
+                            proxy: pool_config.proxy.clone(),
+                            plugins: match pool_config.plugins {
+                                Some(ref plugins) => Some(plugins.clone()),
+                                None => config.plugins.clone(),
+                            },
+                        }),
+                        validated: Arc::new(AtomicBool::new(false)),
+                        paused: Arc::new(AtomicBool::new(false)),
+                        paused_waiter: Arc::new(Notify::new()),
+                        prepared_statement_cache: match pool_config.prepared_statements_cache_size {
+                            0 => None,
+                            _ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
+                                pool_config.prepared_statements_cache_size,
+                            )))),
+                        },
+                    };
 
-                // Connect to the servers to make sure pool configuration is valid
-                // before setting it globally.
-                // Do this async and somewhere else, we don't have to wait here.
-                if config.general.validate_config {
-                    let validate_pool = pool.clone();
-                    tokio::task::spawn(async move {
-                        let _ = validate_pool.validate().await;
-                    });
-                }
-
-                // There is one pool per database/user pair.
-                new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
+                    // Connect to the servers to make sure pool configuration is valid
+                    // before setting it globally.
+                    // Do this async and somewhere else, we don't have to wait here.
+                    if config.general.validate_config {
+                        let validate_pool = pool.clone();
+                        tokio::task::spawn(async move {
+                            let _ = validate_pool.validate().await;
+                        });
+                    }
+
+                    // There is one pool per database/user pair.
+                    new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
+                }
             }
         }
@@ -1247,7 +1260,340 @@ pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
         .cloned()
 }
 
+pub async fn get_or_create_pool(db: &str, user: &str) -> Option<ConnectionPool> {
+    let guard = POOLS.load();
+    let mut pool = guard.get(&PoolIdentifier::new(db, user)).cloned();
+
+    if pool.is_none() {
+        let config = get_config();
+        let pool_config = config.pools.get(db)?;
+
+        if pool_config.proxy {
+            info!("Pool {} is not registered yet, proxy is enabled, creating it on demand", db);
+
+            pool = match create_pool_for_proxy(db, user).await {
+                Ok(pool) => pool,
+                Err(_err) => None,
+            };
+
+            info!("Created a new pool {:?}", pool);
+        }
+    }
+
+    pool
+}
+
 /// Get a pointer to all configured pools.
 pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
     (*(*POOLS.load())).clone()
 }
+
+async fn create_pool_for_proxy(
+    db: &str,
+    psql_user: &str,
+) -> Result<Option<ConnectionPool>, Error> {
+    let config = get_config();
+    let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
+
+    let mut address_id: usize = 0;
+
+    let pool_config_opt = config.pools.get_key_value(db);
+
+    let pool_name: &str = pool_config_opt.unwrap().0;
+    let pool_config = pool_config_opt.unwrap().1;
+
+    info!("Creating a new proxy pool from config {:?}", pool_config);
+
+    let mut user: User = pool_config.users.get("0").unwrap().clone();
+    user.username = psql_user.to_string();
+
+    let mut shards = Vec::new();
+    let mut addresses = Vec::new();
+    let mut banlist = Vec::new();
+    let mut shard_ids = pool_config
+        .shards
+        .clone()
+        .into_keys()
+        .collect::<Vec<String>>();
+
+    // Sort by shard number to ensure consistency.
+    shard_ids.sort_by_key(|k| k.parse::<usize>().unwrap());
+    let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
+
+    for shard_idx in &shard_ids {
+        let shard = &pool_config.shards[shard_idx];
+        let mut pools = Vec::new();
+        let mut servers = Vec::new();
+        let mut replica_number = 0;
+
+        // Load Mirror settings
+        for (address_index, server) in shard.servers.iter().enumerate() {
+            let mut mirror_addresses = vec![];
+            if let Some(mirror_settings_vec) = &shard.mirrors {
+                for (mirror_idx, mirror_settings) in mirror_settings_vec.iter().enumerate() {
+                    if mirror_settings.mirroring_target_index != address_index {
+                        continue;
+                    }
+                    mirror_addresses.push(Address {
+                        id: address_id,
+                        database: db.to_string(),
+                        host: mirror_settings.host.clone(),
+                        port: mirror_settings.port,
+                        role: server.role,
+                        address_index: mirror_idx,
+                        replica_number,
+                        shard: shard_idx.parse::<usize>().unwrap(),
+                        username: user.username.clone(),
+                        pool_name: pool_name.to_string(),
+                        mirrors: vec![],
+                        stats: Arc::new(AddressStats::default()),
+                        error_count: Arc::new(AtomicU64::new(0)),
+                    });
+                    address_id += 1;
+                }
+            }
+
+            let address = Address {
+                id: address_id,
+                database: psql_user.to_string(),
+                host: server.host.clone(),
+                port: server.port,
+                role: server.role,
+                address_index,
+                replica_number,
+                shard: shard_idx.parse::<usize>().unwrap(),
+                username: user.username.clone(),
+                pool_name: pool_name.to_string(),
+                mirrors: mirror_addresses,
+                stats: Arc::new(AddressStats::default()),
+                error_count: Arc::new(AtomicU64::new(0)),
+            };
+
+            address_id += 1;
+
+            if server.role == Role::Replica {
+                replica_number += 1;
+            }
+
+            // We assume every server in the pool shares user/passwords
+            let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
+
+            if let Some(apt) = &auth_passthrough {
+                match apt.fetch_hash(&address).await {
+                    Ok(ok) => {
+                        if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
+                            if ok != *pool_auth_hash_value {
+                                warn!(
+                                    "Hash is not the same across shards \
+                                    of the same pool, client auth will \
+                                    be done using last obtained hash. \
+                                    Server: {}:{}, Database: {}",
+                                    server.host, server.port, shard.database,
+                                );
+                            }
+                        }
+
+                        debug!("Hash obtained for {:?}", address);
+
+                        {
+                            let mut pool_auth_hash = pool_auth_hash.write();
+                            *pool_auth_hash = Some(ok.clone());
+                        }
+                    }
+                    Err(err) => warn!(
+                        "Could not obtain password hashes \
+                        using auth_query config, ignoring. \
+                        Error: {:?}",
+                        err,
+                    ),
+                }
+            }
+
+            let manager = ServerPool::new(
+                address.clone(),
+                user.clone(),
+                psql_user,
+                client_server_map.clone(),
+                pool_auth_hash.clone(),
+                match pool_config.plugins {
+                    Some(ref plugins) => Some(plugins.clone()),
+                    None => config.plugins.clone(),
+                },
+                pool_config.cleanup_server_connections,
+                pool_config.log_client_parameter_status_changes,
+                pool_config.prepared_statements_cache_size,
+            );
+
+            let connect_timeout = match user.connect_timeout {
+                Some(connect_timeout) => connect_timeout,
+                None => match pool_config.connect_timeout {
+                    Some(connect_timeout) => connect_timeout,
+                    None => config.general.connect_timeout,
+                },
+            };
+
+            let idle_timeout = match user.idle_timeout {
+                Some(idle_timeout) => idle_timeout,
+                None => match pool_config.idle_timeout {
+                    Some(idle_timeout) => idle_timeout,
+                    None => config.general.idle_timeout,
+                },
+            };
+
+            let server_lifetime = match user.server_lifetime {
+                Some(server_lifetime) => server_lifetime,
+                None => match pool_config.server_lifetime {
+                    Some(server_lifetime) => server_lifetime,
+                    None => config.general.server_lifetime,
+                },
+            };
+
+            let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
+                .iter()
+                .min()
+                .unwrap();
+
+            let queue_strategy = match config.general.server_round_robin {
+                true => QueueStrategy::Fifo,
+                false => QueueStrategy::Lifo,
+            };
+
+            debug!(
+                "[pool: {}][user: {}] Pool reaper rate: {}ms",
+                pool_name, user.username, reaper_rate
+            );
+
+            let pool = Pool::builder()
+                .max_size(user.pool_size)
+                .min_idle(user.min_pool_size)
+                .connection_timeout(std::time::Duration::from_millis(connect_timeout))
+                .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
+                .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
+                .reaper_rate(std::time::Duration::from_millis(reaper_rate))
+                .queue_strategy(queue_strategy)
+                .test_on_check_out(false);
+
+            let pool = if config.general.validate_config {
+                pool.build(manager).await?
+            } else {
+                pool.build_unchecked(manager)
+            };
+
+            pools.push(pool);
+            servers.push(address);
+        }
+
+        shards.push(pools);
+        addresses.push(servers);
+        banlist.push(HashMap::new());
+    }
+    assert_eq!(shards.len(), addresses.len());
+    if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
+        info!(
+            "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
+            pool_name, user.username
+        );
+    }
+    let new_pool_hash_value = pool_config.hash_value();
+
+    let pool = ConnectionPool {
+        databases: Arc::new(shards),
+        addresses: Arc::new(addresses),
+        banlist: Arc::new(RwLock::new(banlist)),
+        config_hash: new_pool_hash_value,
+        original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
+        auth_hash: pool_auth_hash,
+        settings: Arc::new(PoolSettings {
+            pool_mode: match user.pool_mode {
+                Some(pool_mode) => pool_mode,
+                None => pool_config.pool_mode,
+            },
+            load_balancing_mode: pool_config.load_balancing_mode,
+            checkout_failure_limit: pool_config.checkout_failure_limit,
+            // shards: pool_config.shards.clone(),
+            shards: shard_ids.len(),
+            user: user.clone(),
+            db: pool_name.to_string(),
+            default_role: match pool_config.default_role.as_str() {
+                "any" => None,
+                "replica" => Some(Role::Replica),
+                "primary" => Some(Role::Primary),
+                _ => unreachable!(),
+            },
+            query_parser_enabled: pool_config.query_parser_enabled,
+            query_parser_max_length: pool_config.query_parser_max_length,
+            query_parser_read_write_splitting: pool_config.query_parser_read_write_splitting,
+            primary_reads_enabled: pool_config.primary_reads_enabled,
+            sharding_function: pool_config.sharding_function,
+            db_activity_based_routing: pool_config.db_activity_based_routing,
+            db_activity_init_delay: pool_config.db_activity_init_delay,
+            db_activity_ttl: pool_config.db_activity_ttl,
+            table_mutation_cache_ms_ttl: pool_config.table_mutation_cache_ms_ttl,
+            automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
+            healthcheck_delay: config.general.healthcheck_delay,
+            healthcheck_timeout: config.general.healthcheck_timeout,
+            ban_time: config.general.ban_time,
+            sharding_key_regex: pool_config
+                .sharding_key_regex
+                .clone()
+                .map(|regex| Regex::new(regex.as_str()).unwrap()),
+            shard_id_regex: pool_config
+                .shard_id_regex
+                .clone()
+                .map(|regex| Regex::new(regex.as_str()).unwrap()),
+            regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
+            default_shard: pool_config.default_shard,
+            auth_query: pool_config.auth_query.clone(),
+            auth_query_user: pool_config.auth_query_user.clone(),
+            auth_query_password: pool_config.auth_query_password.clone(),
+            proxy: pool_config.proxy.clone(),
+            plugins: match pool_config.plugins {
+                Some(ref plugins) => Some(plugins.clone()),
+                None => config.plugins.clone(),
+            },
+        }),
+        validated: Arc::new(AtomicBool::new(false)),
+        paused: Arc::new(AtomicBool::new(false)),
+        paused_waiter: Arc::new(Notify::new()),
+        prepared_statement_cache: match pool_config.prepared_statements_cache_size {
+            0 => None,
+            _ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
+                pool_config.prepared_statements_cache_size,
+            )))),
+        },
+    };
+    // Connect to the servers to make sure pool configuration is valid
+    // before setting it globally.
+    // Do this async and somewhere else, we don't have to wait here.
+    if config.general.validate_config {
+        let validate_pool = pool.clone();
+        tokio::task::spawn(async move {
+            let _ = validate_pool.validate().await;
+        });
+    }
+
+    add_connection_pool(PoolIdentifier::new(pool_name, &user.username), pool.clone());
+
+    Ok(Some(pool))
+}
+
+fn add_connection_pool(id: PoolIdentifier, pool: ConnectionPool) {
+    loop {
+        let old_map = POOLS.load(); // Load the current Arc<HashMap<PoolIdentifier, ConnectionPool>>
+        let mut new_map = (**old_map).clone(); // Clone and modify
+        new_map.insert(id.clone(), pool.clone());
+
+        let new_arc = Arc::new(new_map); // Wrap in Arc
+
+        // Attempt atomic swap
+        let previous = POOLS.compare_and_swap(&old_map, new_arc);
+
+        // Check if swap was successful
+        if Arc::ptr_eq(&previous, &old_map) {
+            break; // Success, exit loop
+        }
+
+        // Otherwise, retry (another thread modified it)
+    }
+}
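
Note on the new connection path: it reduces to a cache read followed by on-demand construction that only proxy-enabled pools are allowed to take. A simplified restatement of get_or_create_pool(), using the functions introduced in this patch (`lookup` itself is a hypothetical name, not part of the change):

    // Sketch only: get_pool, get_config and create_pool_for_proxy are the
    // real functions from src/pool.rs; this just restates the control flow.
    async fn lookup(db: &str, user: &str) -> Option<ConnectionPool> {
        // Fast path: the pool was registered at startup or by an earlier lookup.
        if let Some(pool) = get_pool(db, user) {
            return Some(pool);
        }
        // Slow path: only pools marked `proxy = true` may be built lazily,
        // for users that are not listed in the config file.
        match get_config().pools.get(db) {
            Some(pool_config) if pool_config.proxy => {
                create_pool_for_proxy(db, user).await.ok().flatten()
            }
            _ => None,
        }
    }

One consequence worth noting: two clients racing on the same unregistered pool may both take the slow path and build the pool twice; the copy-on-write insert in add_connection_pool() makes the last write win, which appears safe but duplicates work.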
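Proxy mode leans on the existing auth_query passthrough: PgCat logs in as auth_query_user, runs the configured pg_shadow query for the connecting username, and reuses the stored hash both to verify the client and to authenticate the server connection. For md5-era Postgres, that stored value is the literal prefix "md5" followed by hex(md5(password || username)), which is why a single hash per pool suffices and why the code warns when shards disagree. A self-contained illustration of that digest (assumes the `md5` crate, which is not a dependency added by this patch):

    // Illustrative only: how Postgres derives the md5 hash that pg_shadow
    // stores and that auth_query hands back to the pooler.
    fn postgres_md5_hash(username: &str, password: &str) -> String {
        // md5(password concatenated with username), hex-encoded, "md5"-prefixed.
        let digest = md5::compute(format!("{}{}", password, username));
        format!("md5{:x}", digest)
    }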
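Finally, add_connection_pool() publishes the lazily built pool with a copy-on-write retry loop over the global POOLS map: clone the current map, insert, and attempt an atomic pointer swap, retrying if another writer got there first. The same pattern in standalone form (String stands in for PoolIdentifier/ConnectionPool; assumes the arc-swap crate, which the load()/compare_and_swap() calls on POOLS already suggest is in use):

    use std::collections::HashMap;
    use std::sync::Arc;
    use arc_swap::ArcSwap;

    fn insert_cow(pools: &ArcSwap<HashMap<String, String>>, key: String, value: String) {
        loop {
            let old = pools.load_full();            // snapshot: Arc<HashMap<_, _>>
            let mut updated = HashMap::clone(&old); // copy-on-write
            updated.insert(key.clone(), value.clone());
            let prev = pools.compare_and_swap(&old, Arc::new(updated));
            if Arc::ptr_eq(&*prev, &old) {
                break; // no concurrent writer raced us; the new map is visible
            }
            // Lost the race: retry against the fresher snapshot.
        }
    }

arc-swap also ships an rcu() helper that wraps exactly this loop, so the hand-rolled version could likely be replaced by a single POOLS.rcu(...) call.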