From cac21b52019c78a7ab4f236d41f52df07a303b65 Mon Sep 17 00:00:00 2001
From: Ravi Soni <rv.soni@hotmail.com>
Date: Sat, 9 Nov 2024 21:00:27 +0530
Subject: [PATCH 1/4] Implemented PgBouncer-like database proxy functionality.

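pgcat.proxy.toml is a minimal example configuration for the new mode.
Setting proxy = true on a pool makes pgcat behave like PgBouncer:
server pools are created on demand for whichever user connects, rather
than only for users pre-declared in the config. Client credentials are
verified through auth_query against pg_shadow, so only auth_query_user
and auth_query_password need real credentials in the file, and the
users.0 entry serves as a template for pool sizing and pool mode. With
this config a client would connect with something like
`psql -h 127.0.0.1 -p 6433 -U postgres postgres` (illustrative
invocation; host and port come from the [general] section).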
---
 pgcat.proxy.toml | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)
 create mode 100644 pgcat.proxy.toml
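
Note: the auth_query above returns the stored password hash for the
connecting user. As a rough illustration of what such a lookup does
(hypothetical helper using tokio-postgres; pgcat's real implementation
is AuthPassthrough::fetch_hash, which substitutes $1 into the query
text itself, hence the quoted '$1' in the config):

    use tokio_postgres::NoTls;

    /// Fetch the stored password hash (e.g. "md5<hex>" or
    /// "SCRAM-SHA-256$...") for `user`, or None if the user has no
    /// password. Illustrative only: it uses a bind parameter instead
    /// of pgcat's textual $1 substitution.
    async fn fetch_password_hash(
        conninfo: &str, // e.g. "host=localhost user=postgres password=postgres"
        user: &str,
    ) -> Result<Option<String>, tokio_postgres::Error> {
        let (client, conn) = tokio_postgres::connect(conninfo, NoTls).await?;
        tokio::spawn(conn); // drive the connection in the background
        let row = client
            .query_one(
                "SELECT usename, passwd FROM pg_shadow WHERE usename = $1",
                &[&user],
            )
            .await?;
        Ok(row.get(1))
    }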

diff --git a/pgcat.proxy.toml b/pgcat.proxy.toml
new file mode 100644
index 00000000..71ef33d3
--- /dev/null
+++ b/pgcat.proxy.toml
@@ -0,0 +1,33 @@
+# This is an example of the most basic config
+# that will mimic what PgBouncer does in session mode with one server.
+
+[general]
+
+log_level = "DEBUG"
+
+host = "0.0.0.0"
+port = 6433
+admin_username = "pgcat"
+admin_password = "pgcat"
+
+[pools.pgml]
+auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
+
+# Be sure to grant this user LOGIN on Postgres
+auth_query_user = "postgres"
+auth_query_password = "postgres"
+
+proxy = true
+
+[pools.pgml.users.0]
+username = "postgres"
+#password = "postgres"
+pool_size = 10
+min_pool_size = 1
+pool_mode = "session"
+
+[pools.pgml.shards.0]
+servers = [
+  ["localhost", 5432, "primary"]
+]
+database = "postgres"

From 5a3f739cba6a00576f0765320a82c0b360f619d7 Mon Sep 17 00:00:00 2001
From: Ravi Soni <rv.soni@hotmail.com>
Date: Sat, 9 Nov 2024 21:00:34 +0530
Subject: [PATCH 2/4] Implemented PgBouncer-like database proxy functionality.

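Client authentication now goes through get_or_create_pool(): the
global pool map is consulted first, and when no pool exists but the
matching pool config has proxy = true, create_pool_for_proxy() builds
one on demand for the connecting user (cloning the users.0 entry as a
template and overriding its username) and publishes it. Publication is
read-copy-update over the arc-swapped POOLS map: the new map is seeded
from the current one so the store does not drop pools created earlier.
A minimal, self-contained sketch of that lookup pattern (PoolHandle
stands in for the real ConnectionPool; arc-swap and once_cell are
already pgcat dependencies):

    use std::collections::HashMap;
    use std::sync::Arc;

    use arc_swap::ArcSwap;
    use once_cell::sync::Lazy;

    #[derive(Clone, Debug)]
    struct PoolHandle(String); // stand-in for ConnectionPool

    static POOLS: Lazy<ArcSwap<HashMap<(String, String), PoolHandle>>> =
        Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));

    fn get_or_create(db: &str, user: &str) -> PoolHandle {
        let key = (db.to_string(), user.to_string());
        if let Some(pool) = POOLS.load().get(&key) {
            return pool.clone();
        }
        // Read-copy-update: clone the current map, insert, publish.
        // Two racing creators can both build a pool; the last store
        // wins, which is benign as long as creation is idempotent.
        let mut updated = (**POOLS.load()).clone();
        let pool = PoolHandle(format!("{}/{}", db, user));
        updated.insert(key, pool.clone());
        POOLS.store(Arc::new(updated));
        pool
    }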
---
 src/client.rs |   4 +-
 src/config.rs |  11 +
 src/pool.rs   | 838 ++++++++++++++++++++++++++++++++++----------------
 3 files changed, 591 insertions(+), 262 deletions(-)
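
Note: the nested matches for connect_timeout, idle_timeout and
server_lifetime each implement the same three-level fallback: the
per-user setting wins, then the pool-level setting, then the value
from [general]. An equivalent formulation, shown only to document the
precedence:

    /// Per-user value wins, then pool-level, then the [general] default.
    fn effective(user: Option<u64>, pool: Option<u64>, general: u64) -> u64 {
        user.or(pool).unwrap_or(general)
    }

    // effective(None, Some(5_000), 1_000) == 5_000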

diff --git a/src/client.rs b/src/client.rs
index c226436e..7682aab8 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -20,7 +20,7 @@ use crate::config::{
 use crate::constants::*;
 use crate::messages::*;
 use crate::plugins::PluginOutput;
-use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
+use crate::pool::{get_pool, get_or_create_pool, ClientServerMap, ConnectionPool};
 use crate::query_router::{Command, QueryRouter};
 use crate::server::{Server, ServerParameters};
 use crate::stats::{ClientStats, ServerStats};
@@ -557,7 +557,7 @@ where
         }
         // Authenticate normal user.
         else {
-            let pool = match get_pool(pool_name, username) {
+            let pool = match get_or_create_pool(pool_name, username).await {
                 Some(pool) => pool,
                 None => {
                     error_response(
diff --git a/src/config.rs b/src/config.rs
index b0d98fb5..7963a9a0 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -589,6 +589,10 @@ pub struct Pool {
     #[serde(default = "Pool::default_prepared_statements_cache_size")]
     pub prepared_statements_cache_size: usize,
 
+
+    #[serde(default = "Pool::default_proxy")]
+    pub proxy: bool,
+
     pub plugins: Option<Plugins>,
     pub shards: BTreeMap<String, Shard>,
     pub users: BTreeMap<String, User>,
@@ -642,6 +646,8 @@ impl Pool {
         0
     }
 
+    pub fn default_proxy() -> bool {false}
+
     pub fn validate(&mut self) -> Result<(), Error> {
         match self.default_role.as_ref() {
             "any" => (),
@@ -753,6 +759,7 @@ impl Default for Pool {
             cleanup_server_connections: true,
             log_client_parameter_status_changes: false,
             prepared_statements_cache_size: Self::default_prepared_statements_cache_size(),
+            proxy: Self::default_proxy(),
             plugins: None,
             shards: BTreeMap::from([(String::from("1"), Shard::default())]),
             users: BTreeMap::default(),
@@ -1228,6 +1235,10 @@ impl Config {
                 pool_name,
                 pool_config.pool_mode.to_string()
             );
+            info!("[pool: {}] Proxy mode: {}",
+                pool_name,
+                pool_config.proxy.to_string()
+            );
             info!(
                 "[pool: {}] Load Balancing mode: {:?}",
                 pool_name, pool_config.load_balancing_mode
diff --git a/src/pool.rs b/src/pool.rs
index 7915a0a4..7616faad 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -206,6 +206,9 @@ pub struct PoolSettings {
     pub auth_query_user: Option<String>,
     pub auth_query_password: Option<String>,
 
+    // Proxy
+    pub proxy: bool,
+
     /// Plugins
     pub plugins: Option<Plugins>,
 }
@@ -235,6 +238,7 @@ impl Default for PoolSettings {
             auth_query: None,
             auth_query_user: None,
             auth_query_password: None,
+            proxy: false,
             plugins: None,
         }
     }
@@ -293,295 +297,301 @@ impl ConnectionPool {
         for (pool_name, pool_config) in &config.pools {
             let new_pool_hash_value = pool_config.hash_value();
 
-            // There is one pool per database/user pair.
-            for user in pool_config.users.values() {
-                let old_pool_ref = get_pool(pool_name, &user.username);
-                let identifier = PoolIdentifier::new(pool_name, &user.username);
+            let is_proxy :bool = pool_config.proxy;
+            if !is_proxy {
 
-                if let Some(pool) = old_pool_ref {
-                    // If the pool hasn't changed, get existing reference and insert it into the new_pools.
-                    // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
-                    if pool.config_hash == new_pool_hash_value {
-                        info!(
-                            "[pool: {}][user: {}] has not changed",
-                            pool_name, user.username
-                        );
-                        new_pools.insert(identifier.clone(), pool.clone());
-                        continue;
+                // There is one pool per database/user pair.
+                for user in pool_config.users.values() {
+                    let old_pool_ref = get_pool(pool_name, &user.username);
+                    let identifier = PoolIdentifier::new(pool_name, &user.username);
+
+                    if let Some(pool) = old_pool_ref {
+                        // If the pool hasn't changed, get existing reference and insert it into the new_pools.
+                        // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
+                        if pool.config_hash == new_pool_hash_value {
+                            info!(
+                                "[pool: {}][user: {}] has not changed",
+                                pool_name, user.username
+                            );
+                            new_pools.insert(identifier.clone(), pool.clone());
+                            continue;
+                        }
                     }
-                }
 
-                info!(
-                    "[pool: {}][user: {}] creating new pool",
-                    pool_name, user.username
-                );
+                    info!(
+                        "[pool: {}][user: {}] creating new pool",
+                        pool_name, user.username
+                    );
 
-                let mut shards = Vec::new();
-                let mut addresses = Vec::new();
-                let mut banlist = Vec::new();
-                let mut shard_ids = pool_config
-                    .shards
-                    .clone()
-                    .into_keys()
-                    .collect::<Vec<String>>();
-
-                // Sort by shard number to ensure consistency.
-                shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
-                let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
-
-                for shard_idx in &shard_ids {
-                    let shard = &pool_config.shards[shard_idx];
-                    let mut pools = Vec::new();
-                    let mut servers = Vec::new();
-                    let mut replica_number = 0;
-
-                    // Load Mirror settings
-                    for (address_index, server) in shard.servers.iter().enumerate() {
-                        let mut mirror_addresses = vec![];
-                        if let Some(mirror_settings_vec) = &shard.mirrors {
-                            for (mirror_idx, mirror_settings) in
-                                mirror_settings_vec.iter().enumerate()
-                            {
-                                if mirror_settings.mirroring_target_index != address_index {
-                                    continue;
+                    let mut shards = Vec::new();
+                    let mut addresses = Vec::new();
+                    let mut banlist = Vec::new();
+                    let mut shard_ids = pool_config
+                        .shards
+                        .clone()
+                        .into_keys()
+                        .collect::<Vec<String>>();
+
+                    // Sort by shard number to ensure consistency.
+                    shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
+                    let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
+
+                    for shard_idx in &shard_ids {
+                        let shard = &pool_config.shards[shard_idx];
+                        let mut pools = Vec::new();
+                        let mut servers = Vec::new();
+                        let mut replica_number = 0;
+
+                        // Load Mirror settings
+                        for (address_index, server) in shard.servers.iter().enumerate() {
+                            let mut mirror_addresses = vec![];
+                            if let Some(mirror_settings_vec) = &shard.mirrors {
+                                for (mirror_idx, mirror_settings) in
+                                    mirror_settings_vec.iter().enumerate()
+                                {
+                                    if mirror_settings.mirroring_target_index != address_index {
+                                        continue;
+                                    }
+                                    mirror_addresses.push(Address {
+                                        id: address_id,
+                                        database: shard.database.clone(),
+                                        host: mirror_settings.host.clone(),
+                                        port: mirror_settings.port,
+                                        role: server.role,
+                                        address_index: mirror_idx,
+                                        replica_number,
+                                        shard: shard_idx.parse::<usize>().unwrap(),
+                                        username: user.username.clone(),
+                                        pool_name: pool_name.clone(),
+                                        mirrors: vec![],
+                                        stats: Arc::new(AddressStats::default()),
+                                        error_count: Arc::new(AtomicU64::new(0)),
+                                    });
+                                    address_id += 1;
                                 }
-                                mirror_addresses.push(Address {
-                                    id: address_id,
-                                    database: shard.database.clone(),
-                                    host: mirror_settings.host.clone(),
-                                    port: mirror_settings.port,
-                                    role: server.role,
-                                    address_index: mirror_idx,
-                                    replica_number,
-                                    shard: shard_idx.parse::<usize>().unwrap(),
-                                    username: user.username.clone(),
-                                    pool_name: pool_name.clone(),
-                                    mirrors: vec![],
-                                    stats: Arc::new(AddressStats::default()),
-                                    error_count: Arc::new(AtomicU64::new(0)),
-                                });
-                                address_id += 1;
                             }
-                        }
-
-                        let address = Address {
-                            id: address_id,
-                            database: shard.database.clone(),
-                            host: server.host.clone(),
-                            port: server.port,
-                            role: server.role,
-                            address_index,
-                            replica_number,
-                            shard: shard_idx.parse::<usize>().unwrap(),
-                            username: user.username.clone(),
-                            pool_name: pool_name.clone(),
-                            mirrors: mirror_addresses,
-                            stats: Arc::new(AddressStats::default()),
-                            error_count: Arc::new(AtomicU64::new(0)),
-                        };
 
-                        address_id += 1;
-
-                        if server.role == Role::Replica {
-                            replica_number += 1;
-                        }
+                            let address = Address {
+                                id: address_id,
+                                database: shard.database.clone(),
+                                host: server.host.clone(),
+                                port: server.port,
+                                role: server.role,
+                                address_index,
+                                replica_number,
+                                shard: shard_idx.parse::<usize>().unwrap(),
+                                username: user.username.clone(),
+                                pool_name: pool_name.clone(),
+                                mirrors: mirror_addresses,
+                                stats: Arc::new(AddressStats::default()),
+                                error_count: Arc::new(AtomicU64::new(0)),
+                            };
+
+                            address_id += 1;
+
+                            if server.role == Role::Replica {
+                                replica_number += 1;
+                            }
 
-                        // We assume every server in the pool share user/passwords
-                        let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
-
-                        if let Some(apt) = &auth_passthrough {
-                            match apt.fetch_hash(&address).await {
-                                Ok(ok) => {
-                                    if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
-                                    {
-                                        if ok != *pool_auth_hash_value {
-                                            warn!(
-                                                "Hash is not the same across shards \
-                                                of the same pool, client auth will \
-                                                be done using last obtained hash. \
-                                                Server: {}:{}, Database: {}",
-                                                server.host, server.port, shard.database,
-                                            );
+                            // We assume every server in the pool share user/passwords
+                            let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
+
+                            if let Some(apt) = &auth_passthrough {
+                                match apt.fetch_hash(&address).await {
+                                    Ok(ok) => {
+                                        if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
+                                        {
+                                            if ok != *pool_auth_hash_value {
+                                                warn!(
+                                                    "Hash is not the same across shards \
+                                                    of the same pool, client auth will \
+                                                    be done using last obtained hash. \
+                                                    Server: {}:{}, Database: {}",
+                                                    server.host, server.port, shard.database,
+                                                );
+                                            }
                                         }
-                                    }
 
-                                    debug!("Hash obtained for {:?}", address);
+                                        debug!("Hash obtained for {:?}", address);
 
-                                    {
-                                        let mut pool_auth_hash = pool_auth_hash.write();
-                                        *pool_auth_hash = Some(ok.clone());
+                                        {
+                                            let mut pool_auth_hash = pool_auth_hash.write();
+                                            *pool_auth_hash = Some(ok.clone());
+                                        }
                                     }
+                                    Err(err) => warn!(
+                                        "Could not obtain password hashes \
+                                            using auth_query config, ignoring. \
+                                            Error: {:?}",
+                                        err,
+                                    ),
                                 }
-                                Err(err) => warn!(
-                                    "Could not obtain password hashes \
-                                        using auth_query config, ignoring. \
-                                        Error: {:?}",
-                                    err,
-                                ),
                             }
-                        }
-
-                        let manager = ServerPool::new(
-                            address.clone(),
-                            user.clone(),
-                            &shard.database,
-                            client_server_map.clone(),
-                            pool_auth_hash.clone(),
-                            match pool_config.plugins {
-                                Some(ref plugins) => Some(plugins.clone()),
-                                None => config.plugins.clone(),
-                            },
-                            pool_config.cleanup_server_connections,
-                            pool_config.log_client_parameter_status_changes,
-                            pool_config.prepared_statements_cache_size,
-                        );
 
-                        let connect_timeout = match user.connect_timeout {
-                            Some(connect_timeout) => connect_timeout,
-                            None => match pool_config.connect_timeout {
+                            let manager = ServerPool::new(
+                                address.clone(),
+                                user.clone(),
+                                &shard.database,
+                                client_server_map.clone(),
+                                pool_auth_hash.clone(),
+                                match pool_config.plugins {
+                                    Some(ref plugins) => Some(plugins.clone()),
+                                    None => config.plugins.clone(),
+                                },
+                                pool_config.cleanup_server_connections,
+                                pool_config.log_client_parameter_status_changes,
+                                pool_config.prepared_statements_cache_size,
+                            );
+
+                            let connect_timeout = match user.connect_timeout {
                                 Some(connect_timeout) => connect_timeout,
-                                None => config.general.connect_timeout,
-                            },
-                        };
+                                None => match pool_config.connect_timeout {
+                                    Some(connect_timeout) => connect_timeout,
+                                    None => config.general.connect_timeout,
+                                },
+                            };
 
-                        let idle_timeout = match user.idle_timeout {
-                            Some(idle_timeout) => idle_timeout,
-                            None => match pool_config.idle_timeout {
+                            let idle_timeout = match user.idle_timeout {
                                 Some(idle_timeout) => idle_timeout,
-                                None => config.general.idle_timeout,
-                            },
-                        };
+                                None => match pool_config.idle_timeout {
+                                    Some(idle_timeout) => idle_timeout,
+                                    None => config.general.idle_timeout,
+                                },
+                            };
 
-                        let server_lifetime = match user.server_lifetime {
-                            Some(server_lifetime) => server_lifetime,
-                            None => match pool_config.server_lifetime {
+                            let server_lifetime = match user.server_lifetime {
                                 Some(server_lifetime) => server_lifetime,
-                                None => config.general.server_lifetime,
-                            },
-                        };
-
-                        let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
-                            .iter()
-                            .min()
-                            .unwrap();
+                                None => match pool_config.server_lifetime {
+                                    Some(server_lifetime) => server_lifetime,
+                                    None => config.general.server_lifetime,
+                                },
+                            };
+
+                            let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
+                                .iter()
+                                .min()
+                                .unwrap();
+
+                            let queue_strategy = match config.general.server_round_robin {
+                                true => QueueStrategy::Fifo,
+                                false => QueueStrategy::Lifo,
+                            };
+
+                            debug!(
+                                "[pool: {}][user: {}] Pool reaper rate: {}ms",
+                                pool_name, user.username, reaper_rate
+                            );
+
+                            let pool = Pool::builder()
+                                .max_size(user.pool_size)
+                                .min_idle(user.min_pool_size)
+                                .connection_timeout(std::time::Duration::from_millis(connect_timeout))
+                                .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
+                                .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
+                                .reaper_rate(std::time::Duration::from_millis(reaper_rate))
+                                .queue_strategy(queue_strategy)
+                                .test_on_check_out(false);
+
+                            let pool = if config.general.validate_config {
+                                pool.build(manager).await?
+                            } else {
+                                pool.build_unchecked(manager)
+                            };
+
+                            pools.push(pool);
+                            servers.push(address);
+                        }
 
-                        let queue_strategy = match config.general.server_round_robin {
-                            true => QueueStrategy::Fifo,
-                            false => QueueStrategy::Lifo,
-                        };
+                        shards.push(pools);
+                        addresses.push(servers);
+                        banlist.push(HashMap::new());
+                    }
 
-                        debug!(
-                            "[pool: {}][user: {}] Pool reaper rate: {}ms",
-                            pool_name, user.username, reaper_rate
+                    assert_eq!(shards.len(), addresses.len());
+                    if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
+                        info!(
+                            "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
+                            pool_name, user.username
                         );
+                    }
 
-                        let pool = Pool::builder()
-                            .max_size(user.pool_size)
-                            .min_idle(user.min_pool_size)
-                            .connection_timeout(std::time::Duration::from_millis(connect_timeout))
-                            .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
-                            .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
-                            .reaper_rate(std::time::Duration::from_millis(reaper_rate))
-                            .queue_strategy(queue_strategy)
-                            .test_on_check_out(false);
-
-                        let pool = if config.general.validate_config {
-                            pool.build(manager).await?
-                        } else {
-                            pool.build_unchecked(manager)
-                        };
+                    let pool = ConnectionPool {
+                        databases: Arc::new(shards),
+                        addresses: Arc::new(addresses),
+                        banlist: Arc::new(RwLock::new(banlist)),
+                        config_hash: new_pool_hash_value,
+                        original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
+                        auth_hash: pool_auth_hash,
+                        settings: Arc::new(PoolSettings {
+                            pool_mode: match user.pool_mode {
+                                Some(pool_mode) => pool_mode,
+                                None => pool_config.pool_mode,
+                            },
+                            load_balancing_mode: pool_config.load_balancing_mode,
+                            // shards: pool_config.shards.clone(),
+                            shards: shard_ids.len(),
+                            user: user.clone(),
+                            db: pool_name.clone(),
+                            default_role: match pool_config.default_role.as_str() {
+                                "any" => None,
+                                "replica" => Some(Role::Replica),
+                                "primary" => Some(Role::Primary),
+                                _ => unreachable!(),
+                            },
+                            query_parser_enabled: pool_config.query_parser_enabled,
+                            query_parser_max_length: pool_config.query_parser_max_length,
+                            query_parser_read_write_splitting: pool_config
+                                .query_parser_read_write_splitting,
+                            primary_reads_enabled: pool_config.primary_reads_enabled,
+                            sharding_function: pool_config.sharding_function,
+                            automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
+                            healthcheck_delay: config.general.healthcheck_delay,
+                            healthcheck_timeout: config.general.healthcheck_timeout,
+                            ban_time: config.general.ban_time,
+                            sharding_key_regex: pool_config
+                                .sharding_key_regex
+                                .clone()
+                                .map(|regex| Regex::new(regex.as_str()).unwrap()),
+                            shard_id_regex: pool_config
+                                .shard_id_regex
+                                .clone()
+                                .map(|regex| Regex::new(regex.as_str()).unwrap()),
+                            regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
+                            default_shard: pool_config.default_shard,
+                            auth_query: pool_config.auth_query.clone(),
+                            auth_query_user: pool_config.auth_query_user.clone(),
+                            auth_query_password: pool_config.auth_query_password.clone(),
+                            proxy: pool_config.proxy.clone(),
+                            plugins: match pool_config.plugins {
+                                Some(ref plugins) => Some(plugins.clone()),
+                                None => config.plugins.clone(),
+                            },
+                        }),
+                        validated: Arc::new(AtomicBool::new(false)),
+                        paused: Arc::new(AtomicBool::new(false)),
+                        paused_waiter: Arc::new(Notify::new()),
+                        prepared_statement_cache: match pool_config.prepared_statements_cache_size {
+                            0 => None,
+                            _ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
+                                pool_config.prepared_statements_cache_size,
+                            )))),
+                        },
+                    };
 
-                        pools.push(pool);
-                        servers.push(address);
+                    // Connect to the servers to make sure pool configuration is valid
+                    // before setting it globally.
+                    // Do this async and somewhere else, we don't have to wait here.
+                    if config.general.validate_config {
+                        let validate_pool = pool.clone();
+                        tokio::task::spawn(async move {
+                            let _ = validate_pool.validate().await;
+                        });
                     }
 
-                    shards.push(pools);
-                    addresses.push(servers);
-                    banlist.push(HashMap::new());
+                    // There is one pool per database/user pair.
+                    new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
                 }
 
-                assert_eq!(shards.len(), addresses.len());
-                if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
-                    info!(
-                        "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
-                        pool_name, user.username
-                    );
-                }
-
-                let pool = ConnectionPool {
-                    databases: Arc::new(shards),
-                    addresses: Arc::new(addresses),
-                    banlist: Arc::new(RwLock::new(banlist)),
-                    config_hash: new_pool_hash_value,
-                    original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
-                    auth_hash: pool_auth_hash,
-                    settings: Arc::new(PoolSettings {
-                        pool_mode: match user.pool_mode {
-                            Some(pool_mode) => pool_mode,
-                            None => pool_config.pool_mode,
-                        },
-                        load_balancing_mode: pool_config.load_balancing_mode,
-                        // shards: pool_config.shards.clone(),
-                        shards: shard_ids.len(),
-                        user: user.clone(),
-                        db: pool_name.clone(),
-                        default_role: match pool_config.default_role.as_str() {
-                            "any" => None,
-                            "replica" => Some(Role::Replica),
-                            "primary" => Some(Role::Primary),
-                            _ => unreachable!(),
-                        },
-                        query_parser_enabled: pool_config.query_parser_enabled,
-                        query_parser_max_length: pool_config.query_parser_max_length,
-                        query_parser_read_write_splitting: pool_config
-                            .query_parser_read_write_splitting,
-                        primary_reads_enabled: pool_config.primary_reads_enabled,
-                        sharding_function: pool_config.sharding_function,
-                        automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
-                        healthcheck_delay: config.general.healthcheck_delay,
-                        healthcheck_timeout: config.general.healthcheck_timeout,
-                        ban_time: config.general.ban_time,
-                        sharding_key_regex: pool_config
-                            .sharding_key_regex
-                            .clone()
-                            .map(|regex| Regex::new(regex.as_str()).unwrap()),
-                        shard_id_regex: pool_config
-                            .shard_id_regex
-                            .clone()
-                            .map(|regex| Regex::new(regex.as_str()).unwrap()),
-                        regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
-                        default_shard: pool_config.default_shard,
-                        auth_query: pool_config.auth_query.clone(),
-                        auth_query_user: pool_config.auth_query_user.clone(),
-                        auth_query_password: pool_config.auth_query_password.clone(),
-                        plugins: match pool_config.plugins {
-                            Some(ref plugins) => Some(plugins.clone()),
-                            None => config.plugins.clone(),
-                        },
-                    }),
-                    validated: Arc::new(AtomicBool::new(false)),
-                    paused: Arc::new(AtomicBool::new(false)),
-                    paused_waiter: Arc::new(Notify::new()),
-                    prepared_statement_cache: match pool_config.prepared_statements_cache_size {
-                        0 => None,
-                        _ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
-                            pool_config.prepared_statements_cache_size,
-                        )))),
-                    },
-                };
-
-                // Connect to the servers to make sure pool configuration is valid
-                // before setting it globally.
-                // Do this async and somewhere else, we don't have to wait here.
-                if config.general.validate_config {
-                    let validate_pool = pool.clone();
-                    tokio::task::spawn(async move {
-                        let _ = validate_pool.validate().await;
-                    });
-                }
-
-                // There is one pool per database/user pair.
-                new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
             }
         }
 
@@ -1217,7 +1227,315 @@ pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
         .cloned()
 }
 
+pub async fn get_or_create_pool(db: &str, user: &str) -> Option<ConnectionPool> {
+    let guard = POOLS.load();
+    let mut pool = guard.get(&PoolIdentifier::new(db, user)).cloned();
+
+    if pool.is_none() {
+        let config = get_config();
+
+        let pool_config = config.pools.get(db);
+
+        if pool_config.is_some() && pool_config?.proxy {
+            info!("Using existing pool {}, Proxy is enabled", db);
+            
+            pool = match create_pool_for_proxy(db, user).await{
+                Ok(pool) => Option::from(pool.unwrap()),
+                Err(_err) => None
+            };
+                        
+            info!("Created a new pool {:?}", pool);
+        }
+    }
+
+    return pool;
+}
+
 /// Get a pointer to all configured pools.
 pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
     (*(*POOLS.load())).clone()
 }
+
+async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<Option<ConnectionPool>, Error> {
+    let config = get_config();
+    let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
+
+    let mut new_pools = get_all_pools(); // seed with current pools so the store below does not drop them
+    let mut address_id: usize = 0;
+
+    let pool_config_opt = config.pools.get_key_value(db);
+
+    let pool_name: &str = pool_config_opt.unwrap().0;
+    let pool_config = pool_config_opt.unwrap().1;
+
+    info!("Created a new pool {:?}", pool_config);
+
+    let mut user: User = pool_config.users.get("0").unwrap().clone();
+    user.username = psql_user.to_string();
+    
+    let mut shards = Vec::new();
+    let mut addresses = Vec::new();
+    let mut banlist = Vec::new();
+    let mut shard_ids = pool_config
+        .shards
+        .clone()
+        .into_keys()
+        .collect::<Vec<String>>();
+
+    // Sort by shard number to ensure consistency.
+    shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
+    let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
+
+    for shard_idx in &shard_ids {
+        let shard = &pool_config.shards[shard_idx];
+        let mut pools = Vec::new();
+        let mut servers = Vec::new();
+        let mut replica_number = 0;
+
+        // Load Mirror settings
+        for (address_index, server) in shard.servers.iter().enumerate() {
+            let mut mirror_addresses = vec![];
+            if let Some(mirror_settings_vec) = &shard.mirrors {
+                for (mirror_idx, mirror_settings) in
+                    mirror_settings_vec.iter().enumerate()
+                {
+                    if mirror_settings.mirroring_target_index != address_index {
+                        continue;
+                    }
+                    mirror_addresses.push(Address {
+                        id: address_id,
+                        database: shard.database.clone(),
+                        host: mirror_settings.host.clone(),
+                        port: mirror_settings.port,
+                        role: server.role,
+                        address_index: mirror_idx,
+                        replica_number,
+                        shard: shard_idx.parse::<usize>().unwrap(),
+                        username: user.username.clone(),
+                        pool_name: pool_name.to_string(),
+                        mirrors: vec![],
+                        stats: Arc::new(AddressStats::default()),
+                        error_count: Arc::new(AtomicU64::new(0)),
+                    });
+                    address_id += 1;
+                }
+            }
+
+            let address = Address {
+                id: address_id,
+                database: shard.database.clone(),
+                host: server.host.clone(),
+                port: server.port,
+                role: server.role,
+                address_index,
+                replica_number,
+                shard: shard_idx.parse::<usize>().unwrap(),
+                username: user.username.clone(),
+                pool_name: pool_name.to_string(),
+                mirrors: mirror_addresses,
+                stats: Arc::new(AddressStats::default()),
+                error_count: Arc::new(AtomicU64::new(0)),
+            };
+
+            address_id += 1;
+
+            if server.role == Role::Replica {
+                replica_number += 1;
+            }
+
+            // We assume every server in the pool share user/passwords
+            let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
+
+            if let Some(apt) = &auth_passthrough {
+                match apt.fetch_hash(&address).await {
+                    Ok(ok) => {
+                        
+                        if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
+                        {
+                            if ok != *pool_auth_hash_value {
+                                warn!("Hash is not the same across shards \
+                                    of the same pool, client auth will \
+                                    be done using last obtained hash. \
+                                    Server: {}:{}, Database: {}",
+                                    server.host, server.port, shard.database,);
+                            }
+                        }
+
+                        debug!("Hash obtained for {:?}", address);
+
+                        {
+                            let mut pool_auth_hash = pool_auth_hash.write();
+                            *pool_auth_hash = Some(ok.clone());
+                        }
+                    }
+                    Err(err) => warn!(
+                                    "Could not obtain password hashes \
+                                        using auth_query config, ignoring. \
+                                        Error: {:?}",
+                                    err,
+                                ),
+                }
+            }
+
+            let manager = ServerPool::new(
+                address.clone(),
+                user.clone(),
+                &shard.database,
+                client_server_map.clone(),
+                pool_auth_hash.clone(),
+                match pool_config.plugins {
+                    Some(ref plugins) => Some(plugins.clone()),
+                    None => config.plugins.clone(),
+                },
+                pool_config.cleanup_server_connections,
+                pool_config.log_client_parameter_status_changes,
+                pool_config.prepared_statements_cache_size,
+            );
+
+            let connect_timeout = match user.connect_timeout {
+                Some(connect_timeout) => connect_timeout,
+                None => match pool_config.connect_timeout {
+                    Some(connect_timeout) => connect_timeout,
+                    None => config.general.connect_timeout,
+                },
+            };
+
+            let idle_timeout = match user.idle_timeout {
+                Some(idle_timeout) => idle_timeout,
+                None => match pool_config.idle_timeout {
+                    Some(idle_timeout) => idle_timeout,
+                    None => config.general.idle_timeout,
+                },
+            };
+
+            let server_lifetime = match user.server_lifetime {
+                Some(server_lifetime) => server_lifetime,
+                None => match pool_config.server_lifetime {
+                    Some(server_lifetime) => server_lifetime,
+                    None => config.general.server_lifetime,
+                },
+            };
+
+            let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
+                .iter()
+                .min()
+                .unwrap();
+
+            let queue_strategy = match config.general.server_round_robin {
+                true => QueueStrategy::Fifo,
+                false => QueueStrategy::Lifo,
+            };
+
+            debug!("[pool: {}][user: {}] Pool reaper rate: {}ms",
+                pool_name, user.username, reaper_rate);
+
+            let pool = Pool::builder()
+                .max_size(user.pool_size)
+                .min_idle(user.min_pool_size)
+                .connection_timeout(std::time::Duration::from_millis(connect_timeout))
+                .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
+                .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
+                .reaper_rate(std::time::Duration::from_millis(reaper_rate))
+                .queue_strategy(queue_strategy)
+                .test_on_check_out(false);
+
+            let pool = if config.general.validate_config {
+                pool.build(manager).await?
+            } else {
+                pool.build_unchecked(manager)
+            };
+
+            pools.push(pool);
+            servers.push(address);
+        }
+
+        shards.push(pools);
+        addresses.push(servers);
+        banlist.push(HashMap::new());
+    }
+    assert_eq!(shards.len(), addresses.len());
+    if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
+        info!("Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
+              pool_name, user.username);
+    }
+    let new_pool_hash_value = pool_config.hash_value();
+    
+    let pool = ConnectionPool {
+        databases: Arc::new(shards),
+        addresses: Arc::new(addresses),
+        banlist: Arc::new(RwLock::new(banlist)),
+        config_hash: new_pool_hash_value,
+        original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
+        auth_hash: pool_auth_hash,
+        settings: Arc::new(PoolSettings {
+            pool_mode: match user.pool_mode {
+                Some(pool_mode) => pool_mode,
+                None => pool_config.pool_mode,
+            },
+            load_balancing_mode: pool_config.load_balancing_mode,
+            // shards: pool_config.shards.clone(),
+            shards: shard_ids.len(),
+            user: user.clone(),
+            db: pool_name.to_string(),
+            default_role: match pool_config.default_role.as_str() {
+                "any" => None,
+                "replica" => Some(Role::Replica),
+                "primary" => Some(Role::Primary),
+                _ => unreachable!(),
+            },
+            query_parser_enabled: pool_config.query_parser_enabled,
+            query_parser_max_length: pool_config.query_parser_max_length,
+            query_parser_read_write_splitting: pool_config
+                .query_parser_read_write_splitting,
+            primary_reads_enabled: pool_config.primary_reads_enabled,
+            sharding_function: pool_config.sharding_function,
+            automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
+            healthcheck_delay: config.general.healthcheck_delay,
+            healthcheck_timeout: config.general.healthcheck_timeout,
+            ban_time: config.general.ban_time,
+            sharding_key_regex: pool_config
+                .sharding_key_regex
+                .clone()
+                .map(|regex| Regex::new(regex.as_str()).unwrap()),
+            shard_id_regex: pool_config
+                .shard_id_regex
+                .clone()
+                .map(|regex| Regex::new(regex.as_str()).unwrap()),
+            regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
+            default_shard: pool_config.default_shard,
+            auth_query: pool_config.auth_query.clone(),
+            auth_query_user: pool_config.auth_query_user.clone(),
+            auth_query_password: pool_config.auth_query_password.clone(),
+            proxy: pool_config.proxy.clone(),
+            plugins: match pool_config.plugins {
+                Some(ref plugins) => Some(plugins.clone()),
+                None => config.plugins.clone(),
+            },
+        }),
+        validated: Arc::new(AtomicBool::new(false)),
+        paused: Arc::new(AtomicBool::new(false)),
+        paused_waiter: Arc::new(Notify::new()),
+        prepared_statement_cache: match pool_config.prepared_statements_cache_size {
+            0 => None,
+            _ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
+                pool_config.prepared_statements_cache_size,
+            )))),
+        },
+    };
+    // Connect to the servers to make sure pool configuration is valid
+    // before setting it globally.
+    // Do this async and somewhere else, we don't have to wait here.
+    if config.general.validate_config {
+        let validate_pool = pool.clone();
+        tokio::task::spawn(async move {
+            let _ = validate_pool.validate().await;
+        });
+    }
+
+    // There is one pool per database/user pair.
+    new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool.clone());
+
+    POOLS.store(Arc::new(new_pools.clone()));
+
+    Ok(Option::Some::<ConnectionPool>(pool))
+}

From 31a7750c1166a5fc2a5741f39291f8b76535f267 Mon Sep 17 00:00:00 2001
From: Ravi Soni <rv.soni@hotmail.com>
Date: Sat, 9 Nov 2024 21:23:58 +0530
Subject: [PATCH 3/4] Implemented PgBouncer-like database proxy functionality.

---
 src/client.rs |  2 +-
 src/config.rs |  8 +++---
 src/pool.rs   | 67 ++++++++++++++++++++++++++++-----------------------
 3 files changed, 43 insertions(+), 34 deletions(-)

diff --git a/src/client.rs b/src/client.rs
index 7682aab8..3d064b28 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -20,7 +20,7 @@ use crate::config::{
 use crate::constants::*;
 use crate::messages::*;
 use crate::plugins::PluginOutput;
-use crate::pool::{get_pool, get_or_create_pool, ClientServerMap, ConnectionPool};
+use crate::pool::{get_or_create_pool, get_pool, ClientServerMap, ConnectionPool};
 use crate::query_router::{Command, QueryRouter};
 use crate::server::{Server, ServerParameters};
 use crate::stats::{ClientStats, ServerStats};
diff --git a/src/config.rs b/src/config.rs
index 7963a9a0..f26dc267 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -589,7 +589,6 @@ pub struct Pool {
     #[serde(default = "Pool::default_prepared_statements_cache_size")]
     pub prepared_statements_cache_size: usize,
 
-
     #[serde(default = "Pool::default_proxy")]
     pub proxy: bool,
 
@@ -646,7 +645,9 @@ impl Pool {
         0
     }
 
-    pub fn default_proxy() -> bool {false}
+    pub fn default_proxy() -> bool {
+        false
+    }
 
     pub fn validate(&mut self) -> Result<(), Error> {
         match self.default_role.as_ref() {
@@ -1235,7 +1236,8 @@ impl Config {
                 pool_name,
                 pool_config.pool_mode.to_string()
             );
-            info!("[pool: {}] Proxy mode: {}",
+            info!(
+                "[pool: {}] Proxy mode: {}",
                 pool_name,
                 pool_config.proxy.to_string()
             );
diff --git a/src/pool.rs b/src/pool.rs
index 7616faad..4b7ee8b9 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -297,9 +297,8 @@ impl ConnectionPool {
         for (pool_name, pool_config) in &config.pools {
             let new_pool_hash_value = pool_config.hash_value();
 
-            let is_proxy :bool = pool_config.proxy;
+            let is_proxy: bool = pool_config.proxy;
             if !is_proxy {
-
                 // There is one pool per database/user pair.
                 for user in pool_config.users.values() {
                     let old_pool_ref = get_pool(pool_name, &user.username);
@@ -399,7 +398,8 @@ impl ConnectionPool {
                             if let Some(apt) = &auth_passthrough {
                                 match apt.fetch_hash(&address).await {
                                     Ok(ok) => {
-                                        if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
+                                        if let Some(ref pool_auth_hash_value) =
+                                            *(pool_auth_hash.read())
                                         {
                                             if ok != *pool_auth_hash_value {
                                                 warn!(
@@ -485,9 +485,13 @@ impl ConnectionPool {
                             let pool = Pool::builder()
                                 .max_size(user.pool_size)
                                 .min_idle(user.min_pool_size)
-                                .connection_timeout(std::time::Duration::from_millis(connect_timeout))
+                                .connection_timeout(std::time::Duration::from_millis(
+                                    connect_timeout,
+                                ))
                                 .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
-                                .max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
+                                .max_lifetime(Some(std::time::Duration::from_millis(
+                                    server_lifetime,
+                                )))
                                 .reaper_rate(std::time::Duration::from_millis(reaper_rate))
                                 .queue_strategy(queue_strategy)
                                 .test_on_check_out(false);
@@ -591,7 +595,6 @@ impl ConnectionPool {
                     // There is one pool per database/user pair.
                     new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
                 }
-
             }
         }
 
@@ -1238,12 +1241,12 @@ pub async fn get_or_create_pool(db: &str, user: &str) -> Option<ConnectionPool>
 
         if pool_config.is_some() && pool_config?.proxy {
             info!("Using existing pool {}, Proxy is enabled", db);
-            
-            pool = match create_pool_for_proxy(db, user).await{
+
+            pool = match create_pool_for_proxy(db, user).await {
                 Ok(pool) => Option::from(pool.unwrap()),
-                Err(_err) => None
+                Err(_err) => None,
             };
-                        
+
             info!("Created a new pool {:?}", pool);
         }
     }
@@ -1256,7 +1259,10 @@ pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
     (*(*POOLS.load())).clone()
 }
 
-async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<Option<ConnectionPool>, Error> {
+async fn create_pool_for_proxy(
+    db: &str,
+    psql_user: &str,
+) -> Result<Option<ConnectionPool>, Error> {
     let config = get_config();
     let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
 
@@ -1272,7 +1278,7 @@ async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<(Option::<Co
 
     let mut user: User = pool_config.users.get("0").unwrap().clone();
     user.username = psql_user.to_string();
-    
+
     let mut shards = Vec::new();
     let mut addresses = Vec::new();
     let mut banlist = Vec::new();
@@ -1296,9 +1302,7 @@ async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<(Option::<Co
         for (address_index, server) in shard.servers.iter().enumerate() {
             let mut mirror_addresses = vec![];
             if let Some(mirror_settings_vec) = &shard.mirrors {
-                for (mirror_idx, mirror_settings) in
-                    mirror_settings_vec.iter().enumerate()
-                {
+                for (mirror_idx, mirror_settings) in mirror_settings_vec.iter().enumerate() {
                     if mirror_settings.mirroring_target_index != address_index {
                         continue;
                     }
@@ -1349,15 +1353,15 @@ async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<(Option::<Co
             if let Some(apt) = &auth_passthrough {
                 match apt.fetch_hash(&address).await {
                     Ok(ok) => {
-                        
-                        if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
-                        {
+                        if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
                             if ok != *pool_auth_hash_value {
-                                warn!("Hash is not the same across shards \
+                                warn!(
+                                    "Hash is not the same across shards \
                                     of the same pool, client auth will \
                                     be done using last obtained hash. \
                                     Server: {}:{}, Database: {}",
-                                    server.host, server.port, shard.database,);
+                                    server.host, server.port, shard.database,
+                                );
                             }
                         }
 
@@ -1369,11 +1373,11 @@ async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<(Option::<Co
                         }
                     }
                     Err(err) => warn!(
-                                    "Could not obtain password hashes \
+                        "Could not obtain password hashes \
                                         using auth_query config, ignoring. \
                                         Error: {:?}",
-                                    err,
-                                ),
+                        err,
+                    ),
                 }
             }
 
@@ -1426,8 +1430,10 @@ async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<(Option::<Co
                 false => QueueStrategy::Lifo,
             };
 
-            debug!("[pool: {}][user: {}] Pool reaper rate: {}ms",
-                pool_name, user.username, reaper_rate);
+            debug!(
+                "[pool: {}][user: {}] Pool reaper rate: {}ms",
+                pool_name, user.username, reaper_rate
+            );
 
             let pool = Pool::builder()
                 .max_size(user.pool_size)
@@ -1455,11 +1461,13 @@ async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<(Option::<Co
     }
     assert_eq!(shards.len(), addresses.len());
     if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
-        info!("Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
-              pool_name, user.username);
+        info!(
+            "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
+            pool_name, user.username
+        );
     }
     let new_pool_hash_value = pool_config.hash_value();
-    
+
     let pool = ConnectionPool {
         databases: Arc::new(shards),
         addresses: Arc::new(addresses),
@@ -1485,8 +1493,7 @@ async fn create_pool_for_proxy(db: &str, psql_user: &str) -> Result<(Option::<Co
             },
             query_parser_enabled: pool_config.query_parser_enabled,
             query_parser_max_length: pool_config.query_parser_max_length,
-            query_parser_read_write_splitting: pool_config
-                .query_parser_read_write_splitting,
+            query_parser_read_write_splitting: pool_config.query_parser_read_write_splitting,
             primary_reads_enabled: pool_config.primary_reads_enabled,
             sharding_function: pool_config.sharding_function,
             automatic_sharding_key: pool_config.automatic_sharding_key.clone(),

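A note on the error handling above: `get_or_create_pool` flattens the
`Result<Option<ConnectionPool>, Error>` returned by `create_pool_for_proxy`
with `pool.unwrap()`, which panics if the call ever returns `Ok(None)`, and
`Err(err) => None` drops the error unseen. A minimal panic-free sketch of the
same flattening (names follow this patch; the standalone helper itself is
hypothetical, not part of the PR):

    // Hypothetical sketch: collapse Result<Option<ConnectionPool>, Error>
    // into Option<ConnectionPool> without an unwrap() panic path.
    async fn pool_for_proxy(db: &str, user: &str) -> Option<ConnectionPool> {
        match create_pool_for_proxy(db, user).await {
            Ok(pool) => pool, // covers both Ok(Some(_)) and Ok(None)
            Err(err) => {
                log::warn!("proxy pool creation failed for {}/{}: {:?}", db, user, err);
                None
            }
        }
    }

    // Or, when the error really can be discarded:
    //     create_pool_for_proxy(db, user).await.ok().flatten()
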
From 33befee28ff417b6758f5aaea7cdcb533f76216a Mon Sep 17 00:00:00 2001
From: Ravi Soni <rv.soni@hotmail.com>
Date: Sat, 1 Mar 2025 19:59:01 +0530
Subject: [PATCH 4/4] Rebased repo onto latest upstream

---
 pgcat.proxy.toml    |  33 ----
 pgcat.toml          | 359 +++-----------------------------------------
 pgcat.toml.original | 350 ++++++++++++++++++++++++++++++++++++++++++
 src/pool.rs         | 130 +++++-----------
 4 files changed, 405 insertions(+), 467 deletions(-)
 delete mode 100644 pgcat.proxy.toml
 create mode 100644 pgcat.toml.original

diff --git a/pgcat.proxy.toml b/pgcat.proxy.toml
deleted file mode 100644
index 71ef33d3..00000000
--- a/pgcat.proxy.toml
+++ /dev/null
@@ -1,33 +0,0 @@
-# This is an example of the most basic config
-# that will mimic what PgBouncer does in transaction mode with one server.
-
-[general]
-
-log_level = "DEBUG"
-
-host = "0.0.0.0"
-port = 6433
-admin_username = "pgcat"
-admin_password = "pgcat"
-
-[pools.pgml]
-auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
-
-# Be sure to grant this user LOGIN on Postgres
-auth_query_user = "postgres"
-auth_query_password = "postgres"
-
-proxy = true
-
-[pools.pgml.users.0]
-username = "postgres"
-#password = "postgres"
-pool_size = 10
-min_pool_size = 1
-pool_mode = "session"
-
-[pools.pgml.shards.0]
-servers = [
-  ["localhost", 5432, "primary"]
-]
-database = "postgres"
diff --git a/pgcat.toml b/pgcat.toml
index 87f2700c..71ef33d3 100644
--- a/pgcat.toml
+++ b/pgcat.toml
@@ -1,350 +1,33 @@
-#
-# PgCat config example.
-#
+# This is an example of the most basic config
+# that will mimic what PgBouncer does in session mode with one server.
 
-#
-# General pooler settings
 [general]
-# What IP to run on, 0.0.0.0 means accessible from everywhere.
-host = "0.0.0.0"
-
-# Port to run on, same as PgBouncer uses in this example.
-port = 6432
-
-# Whether to enable prometheus exporter or not.
-enable_prometheus_exporter = true
-
-# Port on which the prometheus exporter listens.
-prometheus_exporter_port = 9930
-
-# How long to wait before aborting a server connection (ms).
-connect_timeout = 5000 # milliseconds
-
-# How long an idle connection with a server is left open (ms).
-idle_timeout = 30000 # milliseconds
-
-# Max connection lifetime before it's closed, even if actively used.
-server_lifetime = 86400000 # 24 hours
-
-# How long a client is allowed to be idle while in a transaction (ms).
-idle_client_in_transaction_timeout = 0 # milliseconds
-
-# How much time to give the health check query to return with a result (ms).
-healthcheck_timeout = 1000 # milliseconds
-
-# How long to keep a connection available for immediate re-use, without running a healthcheck query on it
-healthcheck_delay = 30000 # milliseconds
-
-# How much time to give clients during shutdown before forcibly killing client connections (ms).
-shutdown_timeout = 60000 # milliseconds
-
-# How long to ban a server if it fails a health check (seconds).
-ban_time = 60 # seconds
-
-# If we should log client connections
-log_client_connections = false
-
-# If we should log client disconnections
-log_client_disconnections = false
-
-# Interval (ms) at which PgCat checks the config file for changes and reloads it.
-autoreload = 15000
-
-# Number of worker threads the Runtime will use (4 by default).
-worker_threads = 5
-
-# Number of seconds of connection idleness to wait before sending a keepalive packet to the server.
-tcp_keepalives_idle = 5
-# Number of unacknowledged keepalive packets allowed before giving up and closing the connection.
-tcp_keepalives_count = 5
-# Number of seconds between keepalive packets.
-tcp_keepalives_interval = 5
-
-# Path to TLS Certificate file to use for TLS connections
-# tls_certificate = ".circleci/server.cert"
-# Path to TLS private key file to use for TLS connections
-# tls_private_key = ".circleci/server.key"
-
-# Enable/disable server TLS
-server_tls = false
-
-# Verify server certificate is completely authentic.
-verify_server_certificate = false
-
-# User name to access the virtual administrative database (pgbouncer or pgcat)
-# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.
-admin_username = "admin_user"
-# Password to access the virtual administrative database
-admin_password = "admin_pass"
-
-# Default plugins that are configured on all pools.
-[plugins]
-
-# Prewarmer plugin that runs queries on server startup, before giving the connection
-# to the client.
-[plugins.prewarmer]
-enabled = false
-queries = [
-  "SELECT pg_prewarm('pgbench_accounts')",
-]
-
-# Log all queries to stdout.
-[plugins.query_logger]
-enabled = false
-
-# Block access to tables that Postgres does not allow us to control.
-[plugins.table_access]
-enabled = false
-tables = [
-  "pg_user",
-  "pg_roles",
-  "pg_database",
-]
-
-# Intercept user queries and give a fake reply.
-[plugins.intercept]
-enabled = true
-
-[plugins.intercept.queries.0]
-
-query = "select current_database() as a, current_schemas(false) as b"
-schema = [
-  ["a", "text"],
-  ["b", "text"],
-]
-result = [
-  ["${DATABASE}", "{public}"],
-]
-
-[plugins.intercept.queries.1]
-
-query = "select current_database(), current_schema(), current_user"
-schema = [
-  ["current_database", "text"],
-  ["current_schema", "text"],
-  ["current_user", "text"],
-]
-result = [
-  ["${DATABASE}", "public", "${USER}"],
-]
-
-
-# Pool configs are structured as pools.<pool_name>.
-# The pool_name is what clients use as the database name when connecting.
-# For a pool named `sharded_db`, clients access that pool using a connection string like
-# `postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db`
-[pools.sharded_db]
-# Pool mode (see PgBouncer docs for more).
-# `session` one server connection per connected client
-# `transaction` one server connection per client transaction
-pool_mode = "transaction"
-
-# Load balancing mode
-# `random` selects the server at random
-# `loc` selects the server with the least outstanding busy connections
-load_balancing_mode = "random"
-
-# If the client doesn't specify, PgCat routes traffic to this role by default.
-# `any` round-robin between primary and replicas,
-# `replica` round-robin between replicas only without touching the primary,
-# `primary` all queries go to the primary unless otherwise specified.
-default_role = "any"
-
-# Prepared statements cache size.
-# TODO: update documentation
-prepared_statements_cache_size = 500
 
-# If Query Parser is enabled, we'll attempt to parse
-# every incoming query to determine if it's a read or a write.
-# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
-# we'll direct it to the primary.
-query_parser_enabled = true
+log_level = "DEBUG"
 
-# If the query parser is enabled and this setting is enabled, we'll attempt to
-# infer the role from the query itself.
-query_parser_read_write_splitting = true
-
-# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
-# load balancing of read queries. Otherwise, the primary will only be used for write
-# queries. The primary can always be explicitly selected with our custom protocol.
-primary_reads_enabled = true
-
-# Allow sharding commands to be passed as statement comments instead of
-# separate commands. If these are unset this functionality is disabled.
-# sharding_key_regex = '/\* sharding_key: (\d+) \*/'
-# shard_id_regex = '/\* shard_id: (\d+) \*/'
-# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
-
-# Defines the behavior when no shard is selected in a sharded system.
-# `random`: picks a shard at random
-# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
-# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard every time
-# default_shard = "shard_0"
-
-# So what if you wanted to implement a different hashing function,
-# or you've already built one and you want this pooler to use it?
-# Current options:
-# `pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
-# `sha1`: A hashing function based on SHA1
-sharding_function = "pg_bigint_hash"
-
-# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
-# established using the database configured in the pool. This parameter is inherited by every pool
-# and can be redefined in pool configuration.
-# auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
-
-# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
-# specified in `auth_query`. The connection will be established using the database configured in the pool.
-# This parameter is inherited by every pool and can be redefined in pool configuration.
-# auth_query_user = "sharding_user"
-
-# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
-# specified in `auth_query`. The connection will be established using the database configured in the pool.
-# This parameter is inherited by every pool and can be redefined in pool configuration.
-# auth_query_password = "sharding_user"
-
-# Automatically parse this from queries and route queries to the right shard!
-# automatic_sharding_key = "data.id"
-
-# Idle timeout can be overwritten in the pool
-idle_timeout = 40000
-
-# Connect timeout can be overwritten in the pool
-connect_timeout = 3000
-
-# When enabled, IP resolutions for server connections specified using hostnames will be cached
-# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found,
-# old IP connections are closed (gracefully) and new connections will start using the new IP.
-# dns_cache_enabled = false
-
-# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
-# dns_max_ttl = 30
-
-# Plugins can be configured on a per-pool basis. This overrides the global plugins setting,
-# so all plugins have to be configured here again.
-[pools.sharded_db.plugins]
-
-[pools.sharded_db.plugins.prewarmer]
-enabled = true
-queries = [
-  "SELECT pg_prewarm('pgbench_accounts')",
-]
-
-[pools.sharded_db.plugins.query_logger]
-enabled = false
-
-[pools.sharded_db.plugins.table_access]
-enabled = false
-tables = [
-  "pg_user",
-  "pg_roles",
-  "pg_database",
-]
-
-[pools.sharded_db.plugins.intercept]
-enabled = true
-
-[pools.sharded_db.plugins.intercept.queries.0]
-
-query = "select current_database() as a, current_schemas(false) as b"
-schema = [
-  ["a", "text"],
-  ["b", "text"],
-]
-result = [
-  ["${DATABASE}", "{public}"],
-]
-
-[pools.sharded_db.plugins.intercept.queries.1]
-
-query = "select current_database(), current_schema(), current_user"
-schema = [
-  ["current_database", "text"],
-  ["current_schema", "text"],
-  ["current_user", "text"],
-]
-result = [
-  ["${DATABASE}", "public", "${USER}"],
-]
-
-# User configs are structured as pools.<pool_name>.users.<user_index>
-# This section holds the credentials for users that may connect to this cluster
-[pools.sharded_db.users.0]
-# PostgreSQL username used to authenticate the user and connect to the server
-# if `server_username` is not set.
-username = "sharding_user"
-
-# PostgreSQL password used to authenticate the user and connect to the server
-# if `server_password` is not set.
-password = "sharding_user"
-
-pool_mode = "transaction"
-
-# PostgreSQL username used to connect to the server.
-# server_username = "another_user"
-
-# PostgreSQL password used to connect to the server.
-# server_password = "another_password"
-
-# Maximum number of server connections that can be established for this user
-# The maximum number of connections from a single PgCat process to any database in the cluster
-# is the sum of pool_size across all users.
-pool_size = 9
-
-
-# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
-# 0 means it is disabled.
-statement_timeout = 0
-
-[pools.sharded_db.users.1]
-username = "other_user"
-password = "other_user"
-pool_size = 21
-statement_timeout = 15000
-connect_timeout = 1000
-idle_timeout = 1000
-
-# Shard configs are structured as pools.<pool_name>.shards.<shard_id>
-# Each shard config contains a list of servers that make up the shard
-# and the database name to use.
-[pools.sharded_db.shards.0]
-# Array of servers in the shard, each server entry is an array of `[host, port, role]`
-servers = [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
-
-# Array of mirrors for the shard, each mirror entry is an array of `[host, port, index of server in servers array]`
-# Traffic hitting the server identified by the index will be sent to the mirror.
-# mirrors = [["1.2.3.4", 5432, 0], ["1.2.3.4", 5432, 1]]
-
-# Database name (e.g. "postgres")
-database = "shard0"
+host = "0.0.0.0"
+port = 6433
+admin_username = "pgcat"
+admin_password = "pgcat"
 
-[pools.sharded_db.shards.1]
-servers = [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
-database = "shard1"
+[pools.pgml]
+auth_query = "SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
 
-[pools.sharded_db.shards.2]
-servers = [["127.0.0.1", 5432, "primary" ], ["localhost", 5432, "replica" ]]
-database = "shard2"
+# Be sure to grant this user LOGIN on Postgres
+auth_query_user = "postgres"
+auth_query_password = "postgres"
 
+proxy = true
 
-[pools.simple_db]
+[pools.pgml.users.0]
+username = "postgres"
+#password = "postgres"
+pool_size = 10
+min_pool_size = 1
 pool_mode = "session"
-default_role = "primary"
-query_parser_enabled = true
-primary_reads_enabled = true
-sharding_function = "pg_bigint_hash"
-
-[pools.simple_db.users.0]
-username = "simple_user"
-password = "simple_user"
-pool_size = 5
-min_pool_size = 3
-server_lifetime = 60000
-statement_timeout = 0
 
-[pools.simple_db.shards.0]
+[pools.pgml.shards.0]
 servers = [
-    [ "127.0.0.1", 5432, "primary" ],
-    [ "localhost", 5432, "replica" ]
+  ["localhost", 5432, "primary"]
 ]
-database = "some_db"
+database = "postgres"
diff --git a/pgcat.toml.original b/pgcat.toml.original
new file mode 100644
index 00000000..87f2700c
--- /dev/null
+++ b/pgcat.toml.original
@@ -0,0 +1,350 @@
+#
+# PgCat config example.
+#
+
+#
+# General pooler settings
+[general]
+# What IP to run on, 0.0.0.0 means accessible from everywhere.
+host = "0.0.0.0"
+
+# Port to run on, same as PgBouncer uses in this example.
+port = 6432
+
+# Whether to enable prometheus exporter or not.
+enable_prometheus_exporter = true
+
+# Port on which the prometheus exporter listens.
+prometheus_exporter_port = 9930
+
+# How long to wait before aborting a server connection (ms).
+connect_timeout = 5000 # milliseconds
+
+# How long an idle connection with a server is left open (ms).
+idle_timeout = 30000 # milliseconds
+
+# Max connection lifetime before it's closed, even if actively used.
+server_lifetime = 86400000 # 24 hours
+
+# How long a client is allowed to be idle while in a transaction (ms).
+idle_client_in_transaction_timeout = 0 # milliseconds
+
+# How much time to give the health check query to return with a result (ms).
+healthcheck_timeout = 1000 # milliseconds
+
+# How long to keep a connection available for immediate re-use, without running a healthcheck query on it
+healthcheck_delay = 30000 # milliseconds
+
+# How much time to give clients during shutdown before forcibly killing client connections (ms).
+shutdown_timeout = 60000 # milliseconds
+
+# How long to ban a server if it fails a health check (seconds).
+ban_time = 60 # seconds
+
+# If we should log client connections
+log_client_connections = false
+
+# If we should log client disconnections
+log_client_disconnections = false
+
+# Interval (ms) at which PgCat checks the config file for changes and reloads it.
+autoreload = 15000
+
+# Number of worker threads the Runtime will use (4 by default).
+worker_threads = 5
+
+# Number of seconds of connection idleness to wait before sending a keepalive packet to the server.
+tcp_keepalives_idle = 5
+# Number of unacknowledged keepalive packets allowed before giving up and closing the connection.
+tcp_keepalives_count = 5
+# Number of seconds between keepalive packets.
+tcp_keepalives_interval = 5
+
+# Path to TLS Certificate file to use for TLS connections
+# tls_certificate = ".circleci/server.cert"
+# Path to TLS private key file to use for TLS connections
+# tls_private_key = ".circleci/server.key"
+
+# Enable/disable server TLS
+server_tls = false
+
+# Verify server certificate is completely authentic.
+verify_server_certificate = false
+
+# User name to access the virtual administrative database (pgbouncer or pgcat)
+# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.
+admin_username = "admin_user"
+# Password to access the virtual administrative database
+admin_password = "admin_pass"
+
+# Default plugins that are configured on all pools.
+[plugins]
+
+# Prewarmer plugin that runs queries on server startup, before giving the connection
+# to the client.
+[plugins.prewarmer]
+enabled = false
+queries = [
+  "SELECT pg_prewarm('pgbench_accounts')",
+]
+
+# Log all queries to stdout.
+[plugins.query_logger]
+enabled = false
+
+# Block access to tables that Postgres does not allow us to control.
+[plugins.table_access]
+enabled = false
+tables = [
+  "pg_user",
+  "pg_roles",
+  "pg_database",
+]
+
+# Intercept user queries and give a fake reply.
+[plugins.intercept]
+enabled = true
+
+[plugins.intercept.queries.0]
+
+query = "select current_database() as a, current_schemas(false) as b"
+schema = [
+  ["a", "text"],
+  ["b", "text"],
+]
+result = [
+  ["${DATABASE}", "{public}"],
+]
+
+[plugins.intercept.queries.1]
+
+query = "select current_database(), current_schema(), current_user"
+schema = [
+  ["current_database", "text"],
+  ["current_schema", "text"],
+  ["current_user", "text"],
+]
+result = [
+  ["${DATABASE}", "public", "${USER}"],
+]
+
+
+# Pool configs are structured as pools.<pool_name>.
+# The pool_name is what clients use as the database name when connecting.
+# For a pool named `sharded_db`, clients access that pool using a connection string like
+# `postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db`
+[pools.sharded_db]
+# Pool mode (see PgBouncer docs for more).
+# `session` one server connection per connected client
+# `transaction` one server connection per client transaction
+pool_mode = "transaction"
+
+# Load balancing mode
+# `random` selects the server at random
+# `loc` selects the server with the least outstanding busy connections
+load_balancing_mode = "random"
+
+# If the client doesn't specify, PgCat routes traffic to this role by default.
+# `any` round-robin between primary and replicas,
+# `replica` round-robin between replicas only without touching the primary,
+# `primary` all queries go to the primary unless otherwise specified.
+default_role = "any"
+
+# Prepared statements cache size.
+# TODO: update documentation
+prepared_statements_cache_size = 500
+
+# If Query Parser is enabled, we'll attempt to parse
+# every incoming query to determine if it's a read or a write.
+# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
+# we'll direct it to the primary.
+query_parser_enabled = true
+
+# If the query parser is enabled and this setting is enabled, we'll attempt to
+# infer the role from the query itself.
+query_parser_read_write_splitting = true
+
+# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
+# load balancing of read queries. Otherwise, the primary will only be used for write
+# queries. The primary can always be explicitly selected with our custom protocol.
+primary_reads_enabled = true
+
+# Allow sharding commands to be passed as statement comments instead of
+# separate commands. If these are unset this functionality is disabled.
+# sharding_key_regex = '/\* sharding_key: (\d+) \*/'
+# shard_id_regex = '/\* shard_id: (\d+) \*/'
+# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
+
+# Defines the behavior when no shard is selected in a sharded system.
+# `random`: picks a shard at random
+# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
+# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard every time
+# default_shard = "shard_0"
+
+# So what if you wanted to implement a different hashing function,
+# or you've already built one and you want this pooler to use it?
+# Current options:
+# `pg_bigint_hash`: PARTITION BY HASH (Postgres hashing function)
+# `sha1`: A hashing function based on SHA1
+sharding_function = "pg_bigint_hash"
+
+# Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
+# established using the database configured in the pool. This parameter is inherited by every pool
+# and can be redefined in pool configuration.
+# auth_query="SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
+
+# User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
+# specified in `auth_query`. The connection will be established using the database configured in the pool.
+# This parameter is inherited by every pool and can be redefined in pool configuration.
+# auth_query_user = "sharding_user"
+
+# Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
+# specified in `auth_query`. The connection will be established using the database configured in the pool.
+# This parameter is inherited by every pool and can be redefined in pool configuration.
+# auth_query_password = "sharding_user"
+
+# Automatically parse this from queries and route queries to the right shard!
+# automatic_sharding_key = "data.id"
+
+# Idle timeout can be overwritten in the pool
+idle_timeout = 40000
+
+# Connect timeout can be overwritten in the pool
+connect_timeout = 3000
+
+# When enabled, IP resolutions for server connections specified using hostnames will be cached
+# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found,
+# old IP connections are closed (gracefully) and new connections will start using the new IP.
+# dns_cache_enabled = false
+
+# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
+# dns_max_ttl = 30
+
+# Plugins can be configured on a per-pool basis. This overrides the global plugins setting,
+# so all plugins have to be configured here again.
+[pools.sharded_db.plugins]
+
+[pools.sharded_db.plugins.prewarmer]
+enabled = true
+queries = [
+  "SELECT pg_prewarm('pgbench_accounts')",
+]
+
+[pools.sharded_db.plugins.query_logger]
+enabled = false
+
+[pools.sharded_db.plugins.table_access]
+enabled = false
+tables = [
+  "pg_user",
+  "pg_roles",
+  "pg_database",
+]
+
+[pools.sharded_db.plugins.intercept]
+enabled = true
+
+[pools.sharded_db.plugins.intercept.queries.0]
+
+query = "select current_database() as a, current_schemas(false) as b"
+schema = [
+  ["a", "text"],
+  ["b", "text"],
+]
+result = [
+  ["${DATABASE}", "{public}"],
+]
+
+[pools.sharded_db.plugins.intercept.queries.1]
+
+query = "select current_database(), current_schema(), current_user"
+schema = [
+  ["current_database", "text"],
+  ["current_schema", "text"],
+  ["current_user", "text"],
+]
+result = [
+  ["${DATABASE}", "public", "${USER}"],
+]
+
+# User configs are structured as pools.<pool_name>.users.<user_index>
+# This section holds the credentials for users that may connect to this cluster
+[pools.sharded_db.users.0]
+# PostgreSQL username used to authenticate the user and connect to the server
+# if `server_username` is not set.
+username = "sharding_user"
+
+# PostgreSQL password used to authenticate the user and connect to the server
+# if `server_password` is not set.
+password = "sharding_user"
+
+pool_mode = "transaction"
+
+# PostgreSQL username used to connect to the server.
+# server_username = "another_user"
+
+# PostgreSQL password used to connect to the server.
+# server_password = "another_password"
+
+# Maximum number of server connections that can be established for this user
+# The maximum number of connections from a single PgCat process to any database in the cluster
+# is the sum of pool_size across all users.
+pool_size = 9
+
+
+# Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
+# 0 means it is disabled.
+statement_timeout = 0
+
+[pools.sharded_db.users.1]
+username = "other_user"
+password = "other_user"
+pool_size = 21
+statement_timeout = 15000
+connect_timeout = 1000
+idle_timeout = 1000
+
+# Shard configs are structured as pools.<pool_name>.shards.<shard_id>
+# Each shard config contains a list of servers that make up the shard
+# and the database name to use.
+[pools.sharded_db.shards.0]
+# Array of servers in the shard, each server entry is an array of `[host, port, role]`
+servers = [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
+
+# Array of mirrors for the shard, each mirror entry is an array of `[host, port, index of server in servers array]`
+# Traffic hitting the server identified by the index will be sent to the mirror.
+# mirrors = [["1.2.3.4", 5432, 0], ["1.2.3.4", 5432, 1]]
+
+# Database name (e.g. "postgres")
+database = "shard0"
+
+[pools.sharded_db.shards.1]
+servers = [["127.0.0.1", 5432, "primary"], ["localhost", 5432, "replica"]]
+database = "shard1"
+
+[pools.sharded_db.shards.2]
+servers = [["127.0.0.1", 5432, "primary" ], ["localhost", 5432, "replica" ]]
+database = "shard2"
+
+
+[pools.simple_db]
+pool_mode = "session"
+default_role = "primary"
+query_parser_enabled = true
+primary_reads_enabled = true
+sharding_function = "pg_bigint_hash"
+
+[pools.simple_db.users.0]
+username = "simple_user"
+password = "simple_user"
+pool_size = 5
+min_pool_size = 3
+server_lifetime = 60000
+statement_timeout = 0
+
+[pools.simple_db.shards.0]
+servers = [
+    [ "127.0.0.1", 5432, "primary" ],
+    [ "localhost", 5432, "replica" ]
+]
+database = "some_db"
diff --git a/src/pool.rs b/src/pool.rs
index ac93a087..a7d484d2 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -557,6 +557,7 @@ impl ConnectionPool {
                                 None => pool_config.pool_mode,
                             },
                             load_balancing_mode: pool_config.load_balancing_mode,
+                            checkout_failure_limit: pool_config.checkout_failure_limit,
                             // shards: pool_config.shards.clone(),
                             shards: shard_ids.len(),
                             user: user.clone(),
@@ -573,6 +574,10 @@ impl ConnectionPool {
                                 .query_parser_read_write_splitting,
                             primary_reads_enabled: pool_config.primary_reads_enabled,
                             sharding_function: pool_config.sharding_function,
+                            db_activity_based_routing: pool_config.db_activity_based_routing,
+                            db_activity_init_delay: pool_config.db_activity_init_delay,
+                            db_activity_ttl: pool_config.db_activity_ttl,
+                            table_mutation_cache_ms_ttl: pool_config.table_mutation_cache_ms_ttl,
                             automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
                             healthcheck_delay: config.general.healthcheck_delay,
                             healthcheck_timeout: config.general.healthcheck_timeout,
@@ -620,94 +625,6 @@ impl ConnectionPool {
                     // There is one pool per database/user pair.
                     new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
                 }
-
-                assert_eq!(shards.len(), addresses.len());
-                if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
-                    info!(
-                        "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
-                        pool_name, user.username
-                    );
-                }
-
-                let pool = ConnectionPool {
-                    databases: Arc::new(shards),
-                    addresses: Arc::new(addresses),
-                    banlist: Arc::new(RwLock::new(banlist)),
-                    config_hash: new_pool_hash_value,
-                    original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())),
-                    auth_hash: pool_auth_hash,
-                    settings: Arc::new(PoolSettings {
-                        pool_mode: match user.pool_mode {
-                            Some(pool_mode) => pool_mode,
-                            None => pool_config.pool_mode,
-                        },
-                        load_balancing_mode: pool_config.load_balancing_mode,
-                        checkout_failure_limit: pool_config.checkout_failure_limit,
-                        // shards: pool_config.shards.clone(),
-                        shards: shard_ids.len(),
-                        user: user.clone(),
-                        db: pool_name.clone(),
-                        default_role: match pool_config.default_role.as_str() {
-                            "any" => None,
-                            "replica" => Some(Role::Replica),
-                            "primary" => Some(Role::Primary),
-                            _ => unreachable!(),
-                        },
-                        query_parser_enabled: pool_config.query_parser_enabled,
-                        query_parser_max_length: pool_config.query_parser_max_length,
-                        query_parser_read_write_splitting: pool_config
-                            .query_parser_read_write_splitting,
-                        primary_reads_enabled: pool_config.primary_reads_enabled,
-                        sharding_function: pool_config.sharding_function,
-                        db_activity_based_routing: pool_config.db_activity_based_routing,
-                        db_activity_init_delay: pool_config.db_activity_init_delay,
-                        db_activity_ttl: pool_config.db_activity_ttl,
-                        table_mutation_cache_ms_ttl: pool_config.table_mutation_cache_ms_ttl,
-                        automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
-                        healthcheck_delay: config.general.healthcheck_delay,
-                        healthcheck_timeout: config.general.healthcheck_timeout,
-                        ban_time: config.general.ban_time,
-                        sharding_key_regex: pool_config
-                            .sharding_key_regex
-                            .clone()
-                            .map(|regex| Regex::new(regex.as_str()).unwrap()),
-                        shard_id_regex: pool_config
-                            .shard_id_regex
-                            .clone()
-                            .map(|regex| Regex::new(regex.as_str()).unwrap()),
-                        regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
-                        default_shard: pool_config.default_shard,
-                        auth_query: pool_config.auth_query.clone(),
-                        auth_query_user: pool_config.auth_query_user.clone(),
-                        auth_query_password: pool_config.auth_query_password.clone(),
-                        plugins: match pool_config.plugins {
-                            Some(ref plugins) => Some(plugins.clone()),
-                            None => config.plugins.clone(),
-                        },
-                    }),
-                    validated: Arc::new(AtomicBool::new(false)),
-                    paused: Arc::new(AtomicBool::new(false)),
-                    paused_waiter: Arc::new(Notify::new()),
-                    prepared_statement_cache: match pool_config.prepared_statements_cache_size {
-                        0 => None,
-                        _ => Some(Arc::new(Mutex::new(PreparedStatementCache::new(
-                            pool_config.prepared_statements_cache_size,
-                        )))),
-                    },
-                };
-
-                // Connect to the servers to make sure pool configuration is valid
-                // before setting it globally.
-                // Do this async and somewhere else, we don't have to wait here.
-                if config.general.validate_config {
-                    let validate_pool = pool.clone();
-                    tokio::task::spawn(async move {
-                        let _ = validate_pool.validate().await;
-                    });
-                }
-
-                // There is one pool per database/user pair.
-                new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
             }
         }
 
@@ -1357,7 +1274,7 @@ pub async fn get_or_create_pool(db: &str, user: &str) -> Option<ConnectionPool>
 
             pool = match create_pool_for_proxy(db, user).await {
                 Ok(pool) => Option::from(pool.unwrap()),
-                Err(err) => None,
+                Err(_err) => None,
             };
 
             info!("Created a new pool {:?}", pool);
@@ -1375,11 +1292,10 @@ pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
 async fn create_pool_for_proxy(
     db: &str,
     psql_user: &str,
-) -> Result<(Option<ConnectionPool>), Error> {
+) -> Result<Option<ConnectionPool>, Error> {
     let config = get_config();
     let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
 
-    let mut new_pools = HashMap::new();
     let mut address_id: usize = 0;
 
     let pool_config_opt = config.pools.get_key_value(db);
@@ -1497,7 +1413,7 @@ async fn create_pool_for_proxy(
             let manager = ServerPool::new(
                 address.clone(),
                 user.clone(),
-                psql_user.clone(),
+                psql_user,
                 client_server_map.clone(),
                 pool_auth_hash.clone(),
                 match pool_config.plugins {
@@ -1594,6 +1510,7 @@ async fn create_pool_for_proxy(
                 None => pool_config.pool_mode,
             },
             load_balancing_mode: pool_config.load_balancing_mode,
+            checkout_failure_limit: pool_config.checkout_failure_limit,
             // shards: pool_config.shards.clone(),
             shards: shard_ids.len(),
             user: user.clone(),
@@ -1609,6 +1526,10 @@ async fn create_pool_for_proxy(
             query_parser_read_write_splitting: pool_config.query_parser_read_write_splitting,
             primary_reads_enabled: pool_config.primary_reads_enabled,
             sharding_function: pool_config.sharding_function,
+            db_activity_based_routing: pool_config.db_activity_based_routing,
+            db_activity_init_delay: pool_config.db_activity_init_delay,
+            db_activity_ttl: pool_config.db_activity_ttl,
+            table_mutation_cache_ms_ttl: pool_config.table_mutation_cache_ms_ttl,
             automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
             healthcheck_delay: config.general.healthcheck_delay,
             healthcheck_timeout: config.general.healthcheck_timeout,
@@ -1652,10 +1573,27 @@ async fn create_pool_for_proxy(
         });
     }
 
-    // There is one pool per database/user pair.
-    new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool.clone());
-
-    POOLS.store(Arc::new(new_pools.clone()));
+    add_connection_pool(PoolIdentifier::new(pool_name, &user.username), pool.clone());
 
     Ok(Option::Some::<ConnectionPool>(pool))
 }
+
+fn add_connection_pool(id: PoolIdentifier, pool: ConnectionPool) {
+    loop {
+        let old_map = POOLS.load(); // Load current Arc<HashMap<..>>
+        let mut new_map = (**old_map).clone(); // Clone and modify
+        new_map.insert(id.clone(), pool.clone());
+
+        let new_arc = Arc::new(new_map); // Wrap in Arc
+
+        // Attempt atomic swap
+        let previous = POOLS.compare_and_swap(&old_map, new_arc);
+
+        // Check if swap was successful
+        if Arc::ptr_eq(&previous, &old_map) {
+            break; // Success, exit loop
+        }
+
+        // Otherwise, retry (another thread modified it)
+    }
+}
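
A note on `add_connection_pool`: `POOLS` is pgcat's `ArcSwap`-wrapped pool
map, so the function does a copy-on-write update: load the current
`Arc<HashMap<..>>`, clone it, insert, and publish with `compare_and_swap`,
retrying when another writer got there first. The arc-swap crate ships this
retry loop as `rcu`; a minimal equivalent sketch, assuming the same `POOLS`
global (the `_rcu` helper name is hypothetical):

    // Same effect as add_connection_pool, with ArcSwap::rcu driving the
    // clone -> insert -> compare_and_swap cycle until it wins the race.
    fn add_connection_pool_rcu(id: PoolIdentifier, pool: ConnectionPool) {
        POOLS.rcu(|old_map| {
            // rcu may invoke this closure more than once under contention,
            // so it only does idempotent work on owned clones.
            let mut new_map = HashMap::clone(old_map);
            new_map.insert(id.clone(), pool.clone());
            new_map // rcu wraps this in a new Arc and swaps it in
        });
    }

Readers stay lock-free either way: concurrent `get_pool` calls keep seeing
the old map until the new Arc is swapped in.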