From 31c2d433479712fddc55cbf010b7a9e2d961909e Mon Sep 17 00:00:00 2001 From: Andrew Jackson Date: Fri, 7 Feb 2025 11:26:50 -0600 Subject: [PATCH 01/11] Implement connection service file functionality --- asyncpg/connect_utils.py | 128 ++++++++++++++++++++++++++++++++++++++- asyncpg/connection.py | 6 ++ tests/test_connect.py | 70 ++++++++++++++++++++- 3 files changed, 199 insertions(+), 5 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index c65f68a6..98205e6c 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -7,6 +7,7 @@ from __future__ import annotations import asyncio +import configparser import collections from collections.abc import Callable import enum @@ -87,6 +88,9 @@ class SSLNegotiation(compat.StrEnum): PGPASSFILE = '.pgpass' +PG_SERVICEFILE = '.pg_service.conf' + + def _read_password_file(passfile: pathlib.Path) \ -> typing.List[typing.Tuple[str, ...]]: @@ -268,7 +272,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, - password, passfile, database, ssl, + password, passfile, database, ssl, service, direct_tls, server_settings, target_session_attrs, krbsrvname, gsslib): # `auth_hosts` is the version of host information for the purposes @@ -278,6 +282,120 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, ssl_min_protocol_version = ssl_max_protocol_version = None sslnegotiation = None + if dsn: + parsed = urllib.parse.urlparse(dsn) + if parsed.query: + query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) + for key, val in query.items(): + if isinstance(val, list): + query[key] = val[-1] + + if 'service' in query: + val = query.pop('service') + if not service and val: + service = val + + connection_service_file = os.getenv('PGSERVICEFILE') + if connection_service_file is None: + homedir = compat.get_pg_home_directory() + if homedir: + connection_service_file = homedir / PG_SERVICEFILE + else: + connection_service_file = None + else: + connection_service_file = pathlib.Path(connection_service_file) + + if connection_service_file is not None and service is not None: + # TODO Open and parse connection service file + pg_service = configparser.ConfigParser() + pg_service.read(connection_service_file) + if service in pg_service.sections(): + service_params = pg_service[service] + if 'port' in service_params: + val = service_params.pop('port') + if not port and val: + port = [int(p) for p in val.split(',')] + + if 'host' in service_params: + val = service_params.pop('host') + if not host and val: + host, port = _parse_hostlist(val, port) + + if 'dbname' in service_params: + val = service_params.pop('dbname') + if database is None: + database = val + + if 'database' in service_params: + val = service_params.pop('database') + if database is None: + database = val + + if 'user' in service_params: + val = service_params.pop('user') + if user is None: + user = val + + if 'password' in service_params: + val = service_params.pop('password') + if password is None: + password = val + + if 'passfile' in service_params: + val = service_params.pop('passfile') + if passfile is None: + passfile = val + + if 'sslmode' in service_params: + val = service_params.pop('sslmode') + if ssl is None: + ssl = val + + if 'sslcert' in service_params: + sslcert = service_params.pop('sslcert') + + if 'sslkey' in service_params: + sslkey = service_params.pop('sslkey') + + if 'sslrootcert' in service_params: + sslrootcert = service_params.pop('sslrootcert') + + if 'sslnegotiation' in service_params: + sslnegotiation = service_params.pop('sslnegotiation') + + if 'sslcrl' in service_params: + sslcrl = service_params.pop('sslcrl') + + if 'sslpassword' in service_params: + sslpassword = service_params.pop('sslpassword') + + if 'ssl_min_protocol_version' in service_params: + ssl_min_protocol_version = service_params.pop( + 'ssl_min_protocol_version' + ) + + if 'ssl_max_protocol_version' in service_params: + ssl_max_protocol_version = service_params.pop( + 'ssl_max_protocol_version' + ) + + if 'target_session_attrs' in service_params: + dsn_target_session_attrs = service_params.pop( + 'target_session_attrs' + ) + if target_session_attrs is None: + target_session_attrs = dsn_target_session_attrs + + if 'krbsrvname' in service_params: + val = service_params.pop('krbsrvname') + if krbsrvname is None: + krbsrvname = val + + if 'gsslib' in service_params: + val = service_params.pop('gsslib') + if gsslib is None: + gsslib = val + if dsn: parsed = urllib.parse.urlparse(dsn) @@ -406,6 +524,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if gsslib is None: gsslib = val + if 'service' in query: + val = query.pop('service') + if query: if server_settings is None: server_settings = query @@ -491,6 +612,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, database=database, user=user, passfile=passfile) + addrs = [] have_tcp_addrs = False for h, p in zip(host, port): @@ -724,7 +846,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib): + target_session_attrs, krbsrvname, gsslib, service): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -754,7 +876,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, direct_tls=direct_tls, database=database, server_settings=server_settings, target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib) + krbsrvname=krbsrvname, gsslib=gsslib, service=service) config = _ClientConfiguration( command_timeout=command_timeout, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 3a86466c..0ee87861 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2074,6 +2074,7 @@ async def _do_execute( async def connect(dsn=None, *, host=None, port=None, user=None, password=None, passfile=None, + service=None, database=None, loop=None, timeout=60, @@ -2183,6 +2184,10 @@ async def connect(dsn=None, *, (defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf`` on Windows). + :param service: + The name of the postgres connection service stored in the postgres + connection service file. + :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -2428,6 +2433,7 @@ async def connect(dsn=None, *, user=user, password=password, passfile=passfile, + service=service, ssl=ssl, direct_tls=direct_tls, database=database, diff --git a/tests/test_connect.py b/tests/test_connect.py index 0037ee5e..4ac46578 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1116,7 +1116,8 @@ def run_testcase(self, testcase): env = testcase.get('env', {}) test_env = {'PGHOST': None, 'PGPORT': None, 'PGUSER': None, 'PGPASSWORD': None, - 'PGDATABASE': None, 'PGSSLMODE': None} + 'PGDATABASE': None, 'PGSSLMODE': None, + 'PGSERVICE': None, } test_env.update(env) dsn = testcase.get('dsn') @@ -1132,6 +1133,7 @@ def run_testcase(self, testcase): target_session_attrs = testcase.get('target_session_attrs') krbsrvname = testcase.get('krbsrvname') gsslib = testcase.get('gsslib') + service = testcase.get('service') expected = testcase.get('result') expected_error = testcase.get('error') @@ -1157,7 +1159,7 @@ def run_testcase(self, testcase): direct_tls=direct_tls, server_settings=server_settings, target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib) + krbsrvname=krbsrvname, gsslib=gsslib, service=service) params = { k: v for k, v in params._asdict().items() @@ -1236,6 +1238,70 @@ def test_connect_params(self): for testcase in self.TESTS: self.run_testcase(testcase) + def test_connect_connection_service_file(self): + connection_service_file = tempfile.NamedTemporaryFile('w+t', delete=False) + connection_service_file.write(textwrap.dedent(f''' +[test_service_dbname] +port=5433 +host=somehost +dbname=test_dbname +user=admin +password=test_password +target_session_attrs=primary +krbsrvname=fakekrbsrvname +gsslib=sspi + +[test_service_database] +port=5433 +host=somehost +database=test_dbname +user=admin +password=test_password +target_session_attrs=primary +krbsrvname=fakekrbsrvname +gsslib=sspi + ''')) + connection_service_file.close() + os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR) + try: + # passfile path in env + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_dbname', + 'env': { + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_database', + 'env': { + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) + finally: + os.unlink(connection_service_file.name) + def test_connect_pgpass_regular(self): passfile = tempfile.NamedTemporaryFile('w+t', delete=False) passfile.write(textwrap.dedent(R''' From 1f87efd3fc94b3620deec48c4ee6561adc93b416 Mon Sep 17 00:00:00 2001 From: Andrew Jackson <46945903+AndrewJackson2020@users.noreply.github.com> Date: Wed, 19 Mar 2025 15:54:57 -0500 Subject: [PATCH 02/11] Update asyncpg/connect_utils.py Co-authored-by: Elvis Pranskevichus --- asyncpg/connect_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 98205e6c..25488003 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -526,6 +526,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if 'service' in query: val = query.pop('service') + if service is None: + service = val if query: if server_settings is None: From b1a723b17b3b208afeaca7c1723a0e8faddac6cb Mon Sep 17 00:00:00 2001 From: CommanderKeynes Date: Wed, 19 Mar 2025 21:26:10 -0500 Subject: [PATCH 03/11] Cosolidate if statements --- asyncpg/connect_utils.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 25488003..85301117 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -284,6 +284,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if dsn: parsed = urllib.parse.urlparse(dsn) + + query = None if parsed.query: query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) for key, val in query.items(): @@ -306,7 +308,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, connection_service_file = pathlib.Path(connection_service_file) if connection_service_file is not None and service is not None: - # TODO Open and parse connection service file pg_service = configparser.ConfigParser() pg_service.read(connection_service_file) if service in pg_service.sections(): @@ -396,8 +397,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if gsslib is None: gsslib = val - if dsn: - parsed = urllib.parse.urlparse(dsn) if parsed.scheme not in {'postgresql', 'postgres'}: raise exceptions.ClientConfigurationError( @@ -433,11 +432,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if password is None and dsn_password: password = urllib.parse.unquote(dsn_password) - if parsed.query: - query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) - for key, val in query.items(): - if isinstance(val, list): - query[key] = val[-1] + if query: if 'port' in query: val = query.pop('port') From 1c337cd1e14a9e756fc8752a637e39a1468e9d7c Mon Sep 17 00:00:00 2001 From: CommanderKeynes Date: Wed, 19 Mar 2025 21:42:59 -0500 Subject: [PATCH 04/11] Fix formatting --- asyncpg/connect_utils.py | 175 +++++++++++++++++++-------------------- tests/test_connect.py | 75 ++++++++--------- 2 files changed, 125 insertions(+), 125 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 85301117..415e5271 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -305,98 +305,97 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, else: connection_service_file = None else: - connection_service_file = pathlib.Path(connection_service_file) + connection_service_file = pathlib.Path(connection_service_file) if connection_service_file is not None and service is not None: pg_service = configparser.ConfigParser() pg_service.read(connection_service_file) if service in pg_service.sections(): - service_params = pg_service[service] - if 'port' in service_params: - val = service_params.pop('port') - if not port and val: - port = [int(p) for p in val.split(',')] - - if 'host' in service_params: - val = service_params.pop('host') - if not host and val: - host, port = _parse_hostlist(val, port) - - if 'dbname' in service_params: - val = service_params.pop('dbname') - if database is None: - database = val - - if 'database' in service_params: - val = service_params.pop('database') - if database is None: - database = val - - if 'user' in service_params: - val = service_params.pop('user') - if user is None: - user = val - - if 'password' in service_params: - val = service_params.pop('password') - if password is None: - password = val - - if 'passfile' in service_params: - val = service_params.pop('passfile') - if passfile is None: - passfile = val - - if 'sslmode' in service_params: - val = service_params.pop('sslmode') - if ssl is None: - ssl = val - - if 'sslcert' in service_params: - sslcert = service_params.pop('sslcert') - - if 'sslkey' in service_params: - sslkey = service_params.pop('sslkey') - - if 'sslrootcert' in service_params: - sslrootcert = service_params.pop('sslrootcert') - - if 'sslnegotiation' in service_params: - sslnegotiation = service_params.pop('sslnegotiation') - - if 'sslcrl' in service_params: - sslcrl = service_params.pop('sslcrl') - - if 'sslpassword' in service_params: - sslpassword = service_params.pop('sslpassword') - - if 'ssl_min_protocol_version' in service_params: - ssl_min_protocol_version = service_params.pop( - 'ssl_min_protocol_version' - ) - - if 'ssl_max_protocol_version' in service_params: - ssl_max_protocol_version = service_params.pop( - 'ssl_max_protocol_version' - ) - - if 'target_session_attrs' in service_params: - dsn_target_session_attrs = service_params.pop( - 'target_session_attrs' - ) - if target_session_attrs is None: - target_session_attrs = dsn_target_session_attrs - - if 'krbsrvname' in service_params: - val = service_params.pop('krbsrvname') - if krbsrvname is None: - krbsrvname = val - - if 'gsslib' in service_params: - val = service_params.pop('gsslib') - if gsslib is None: - gsslib = val + service_params = pg_service[service] + if 'port' in service_params: + val = service_params.pop('port') + if not port and val: + port = [int(p) for p in val.split(',')] + + if 'host' in service_params: + val = service_params.pop('host') + if not host and val: + host, port = _parse_hostlist(val, port) + + if 'dbname' in service_params: + val = service_params.pop('dbname') + if database is None: + database = val + + if 'database' in service_params: + val = service_params.pop('database') + if database is None: + database = val + + if 'user' in service_params: + val = service_params.pop('user') + if user is None: + user = val + + if 'password' in service_params: + val = service_params.pop('password') + if password is None: + password = val + + if 'passfile' in service_params: + val = service_params.pop('passfile') + if passfile is None: + passfile = val + + if 'sslmode' in service_params: + val = service_params.pop('sslmode') + if ssl is None: + ssl = val + + if 'sslcert' in service_params: + sslcert = service_params.pop('sslcert') + + if 'sslkey' in service_params: + sslkey = service_params.pop('sslkey') + + if 'sslrootcert' in service_params: + sslrootcert = service_params.pop('sslrootcert') + + if 'sslnegotiation' in service_params: + sslnegotiation = service_params.pop('sslnegotiation') + + if 'sslcrl' in service_params: + sslcrl = service_params.pop('sslcrl') + + if 'sslpassword' in service_params: + sslpassword = service_params.pop('sslpassword') + + if 'ssl_min_protocol_version' in service_params: + ssl_min_protocol_version = service_params.pop( + 'ssl_min_protocol_version' + ) + + if 'ssl_max_protocol_version' in service_params: + ssl_max_protocol_version = service_params.pop( + 'ssl_max_protocol_version' + ) + if 'target_session_attrs' in service_params: + dsn_target_session_attrs = service_params.pop( + 'target_session_attrs' + ) + if target_session_attrs is None: + target_session_attrs = dsn_target_session_attrs + + if 'krbsrvname' in service_params: + val = service_params.pop('krbsrvname') + if krbsrvname is None: + krbsrvname = val + + if 'gsslib' in service_params: + val = service_params.pop('gsslib') + if gsslib is None: + gsslib = val if parsed.scheme not in {'postgresql', 'postgres'}: raise exceptions.ClientConfigurationError( @@ -609,7 +608,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, database=database, user=user, passfile=passfile) - addrs = [] have_tcp_addrs = False for h, p in zip(host, port): @@ -843,7 +841,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib, service): + target_session_attrs, krbsrvname, gsslib, + service): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', diff --git a/tests/test_connect.py b/tests/test_connect.py index 4ac46578..4a5445ba 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1239,8 +1239,9 @@ def test_connect_params(self): self.run_testcase(testcase) def test_connect_connection_service_file(self): - connection_service_file = tempfile.NamedTemporaryFile('w+t', delete=False) - connection_service_file.write(textwrap.dedent(f''' + connection_service_file = tempfile.NamedTemporaryFile( + 'w+t', delete=False) + connection_service_file.write(textwrap.dedent(''' [test_service_dbname] port=5433 host=somehost @@ -1264,41 +1265,41 @@ def test_connect_connection_service_file(self): connection_service_file.close() os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR) try: - # passfile path in env - self.run_testcase({ - 'dsn': 'postgresql://?service=test_service_dbname', - 'env': { - 'PGSERVICEFILE': connection_service_file.name - }, - 'result': ( - [('somehost', 5433)], - { - 'user': 'admin', - 'password': 'test_password', - 'database': 'test_dbname', - 'target_session_attrs': 'primary', - 'krbsrvname': 'fakekrbsrvname', - 'gsslib': 'sspi', - } - ) - }) - self.run_testcase({ - 'dsn': 'postgresql://?service=test_service_database', - 'env': { - 'PGSERVICEFILE': connection_service_file.name - }, - 'result': ( - [('somehost', 5433)], - { - 'user': 'admin', - 'password': 'test_password', - 'database': 'test_dbname', - 'target_session_attrs': 'primary', - 'krbsrvname': 'fakekrbsrvname', - 'gsslib': 'sspi', - } - ) - }) + # passfile path in env + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_dbname', + 'env': { + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_database', + 'env': { + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) finally: os.unlink(connection_service_file.name) From 8558a460d66e5a5e027052d555060694d6e61928 Mon Sep 17 00:00:00 2001 From: CommanderKeynes Date: Thu, 20 Mar 2025 19:05:40 -0500 Subject: [PATCH 05/11] Move con service file parse below query --- asyncpg/connect_utils.py | 180 +++++++++++++++++++-------------------- 1 file changed, 90 insertions(+), 90 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 415e5271..afefc1d2 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -307,96 +307,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, else: connection_service_file = pathlib.Path(connection_service_file) - if connection_service_file is not None and service is not None: - pg_service = configparser.ConfigParser() - pg_service.read(connection_service_file) - if service in pg_service.sections(): - service_params = pg_service[service] - if 'port' in service_params: - val = service_params.pop('port') - if not port and val: - port = [int(p) for p in val.split(',')] - - if 'host' in service_params: - val = service_params.pop('host') - if not host and val: - host, port = _parse_hostlist(val, port) - - if 'dbname' in service_params: - val = service_params.pop('dbname') - if database is None: - database = val - - if 'database' in service_params: - val = service_params.pop('database') - if database is None: - database = val - - if 'user' in service_params: - val = service_params.pop('user') - if user is None: - user = val - - if 'password' in service_params: - val = service_params.pop('password') - if password is None: - password = val - - if 'passfile' in service_params: - val = service_params.pop('passfile') - if passfile is None: - passfile = val - - if 'sslmode' in service_params: - val = service_params.pop('sslmode') - if ssl is None: - ssl = val - - if 'sslcert' in service_params: - sslcert = service_params.pop('sslcert') - - if 'sslkey' in service_params: - sslkey = service_params.pop('sslkey') - - if 'sslrootcert' in service_params: - sslrootcert = service_params.pop('sslrootcert') - - if 'sslnegotiation' in service_params: - sslnegotiation = service_params.pop('sslnegotiation') - - if 'sslcrl' in service_params: - sslcrl = service_params.pop('sslcrl') - - if 'sslpassword' in service_params: - sslpassword = service_params.pop('sslpassword') - - if 'ssl_min_protocol_version' in service_params: - ssl_min_protocol_version = service_params.pop( - 'ssl_min_protocol_version' - ) - - if 'ssl_max_protocol_version' in service_params: - ssl_max_protocol_version = service_params.pop( - 'ssl_max_protocol_version' - ) - - if 'target_session_attrs' in service_params: - dsn_target_session_attrs = service_params.pop( - 'target_session_attrs' - ) - if target_session_attrs is None: - target_session_attrs = dsn_target_session_attrs - - if 'krbsrvname' in service_params: - val = service_params.pop('krbsrvname') - if krbsrvname is None: - krbsrvname = val - - if 'gsslib' in service_params: - val = service_params.pop('gsslib') - if gsslib is None: - gsslib = val - if parsed.scheme not in {'postgresql', 'postgres'}: raise exceptions.ClientConfigurationError( 'invalid DSN: scheme is expected to be either ' @@ -529,6 +439,96 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, else: server_settings = {**query, **server_settings} + if connection_service_file is not None and service is not None: + pg_service = configparser.ConfigParser() + pg_service.read(connection_service_file) + if service in pg_service.sections(): + service_params = pg_service[service] + if 'port' in service_params: + val = service_params.pop('port') + if not port and val: + port = [int(p) for p in val.split(',')] + + if 'host' in service_params: + val = service_params.pop('host') + if not host and val: + host, port = _parse_hostlist(val, port) + + if 'dbname' in service_params: + val = service_params.pop('dbname') + if database is None: + database = val + + if 'database' in service_params: + val = service_params.pop('database') + if database is None: + database = val + + if 'user' in service_params: + val = service_params.pop('user') + if user is None: + user = val + + if 'password' in service_params: + val = service_params.pop('password') + if password is None: + password = val + + if 'passfile' in service_params: + val = service_params.pop('passfile') + if passfile is None: + passfile = val + + if 'sslmode' in service_params: + val = service_params.pop('sslmode') + if ssl is None: + ssl = val + + if 'sslcert' in service_params: + sslcert = service_params.pop('sslcert') + + if 'sslkey' in service_params: + sslkey = service_params.pop('sslkey') + + if 'sslrootcert' in service_params: + sslrootcert = service_params.pop('sslrootcert') + + if 'sslnegotiation' in service_params: + sslnegotiation = service_params.pop('sslnegotiation') + + if 'sslcrl' in service_params: + sslcrl = service_params.pop('sslcrl') + + if 'sslpassword' in service_params: + sslpassword = service_params.pop('sslpassword') + + if 'ssl_min_protocol_version' in service_params: + ssl_min_protocol_version = service_params.pop( + 'ssl_min_protocol_version' + ) + + if 'ssl_max_protocol_version' in service_params: + ssl_max_protocol_version = service_params.pop( + 'ssl_max_protocol_version' + ) + + if 'target_session_attrs' in service_params: + dsn_target_session_attrs = service_params.pop( + 'target_session_attrs' + ) + if target_session_attrs is None: + target_session_attrs = dsn_target_session_attrs + + if 'krbsrvname' in service_params: + val = service_params.pop('krbsrvname') + if krbsrvname is None: + krbsrvname = val + + if 'gsslib' in service_params: + val = service_params.pop('gsslib') + if gsslib is None: + gsslib = val + if not host: hostspec = os.environ.get('PGHOST') if hostspec: From b6fb1e4dc3a7b8b2d3c5b228b11ed970e189d98b Mon Sep 17 00:00:00 2001 From: CommanderKeynes Date: Thu, 20 Mar 2025 19:18:49 -0500 Subject: [PATCH 06/11] Add tests --- tests/test_connect.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index 4a5445ba..709394f6 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1265,7 +1265,7 @@ def test_connect_connection_service_file(self): connection_service_file.close() os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR) try: - # passfile path in env + # Test connection service file with dbname self.run_testcase({ 'dsn': 'postgresql://?service=test_service_dbname', 'env': { @@ -1283,6 +1283,7 @@ def test_connect_connection_service_file(self): } ) }) + # Test connection service file with database self.run_testcase({ 'dsn': 'postgresql://?service=test_service_database', 'env': { @@ -1300,6 +1301,43 @@ def test_connect_connection_service_file(self): } ) }) + # Test that envvars are overridden by service file + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_dbname', + 'env': { + 'PGUSER': 'user', + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) + # Test that dsn params overwrite service file + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_dbname&dbname=test_dbname_dsn', + 'env': { + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname_dsn', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) finally: os.unlink(connection_service_file.name) From 93347b0ab05f56c11cf4226dd0e7090cb1dd4ec2 Mon Sep 17 00:00:00 2001 From: CommanderKeynes Date: Thu, 20 Mar 2025 19:53:07 -0500 Subject: [PATCH 07/11] Fix format --- tests/test_connect.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index 709394f6..8ffaa46e 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1322,7 +1322,9 @@ def test_connect_connection_service_file(self): }) # Test that dsn params overwrite service file self.run_testcase({ - 'dsn': 'postgresql://?service=test_service_dbname&dbname=test_dbname_dsn', + 'dsn': 'postgresql://?service={}&dbname={}'.format( + "test_service_dbname", "test_dbname_dsn" + ), 'env': { 'PGSERVICEFILE': connection_service_file.name }, From c02781c3a32b0caf42491fec153367819c8e4b3c Mon Sep 17 00:00:00 2001 From: Andrew Jackson <46945903+AndrewJackson2020@users.noreply.github.com> Date: Fri, 21 Mar 2025 12:01:15 -0500 Subject: [PATCH 08/11] Update asyncpg/connect_utils.py Co-authored-by: Elvis Pranskevichus --- asyncpg/connect_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index afefc1d2..20c1f524 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -528,7 +528,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, val = service_params.pop('gsslib') if gsslib is None: gsslib = val - + if not service: + service = os.environ.get('PGSERVICE') if not host: hostspec = os.environ.get('PGHOST') if hostspec: From fc212409d3c39b74cc104e626ddaa2b1ebbe63a4 Mon Sep 17 00:00:00 2001 From: Andrew Jackson Date: Wed, 26 Mar 2025 13:44:20 -0500 Subject: [PATCH 09/11] fix ssl handling --- asyncpg/connect_utils.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 20c1f524..34775af8 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -485,32 +485,48 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, ssl = val if 'sslcert' in service_params: - sslcert = service_params.pop('sslcert') + val = service_params.pop('sslcert') + if sslcert is None: + sslcert = val if 'sslkey' in service_params: - sslkey = service_params.pop('sslkey') + val = service_params.pop('sslkey') + if sslkey is None: + sslkey = val if 'sslrootcert' in service_params: - sslrootcert = service_params.pop('sslrootcert') + val = service_params.pop('sslrootcert') + if sslrootcert is None: + sslrootcert = val if 'sslnegotiation' in service_params: - sslnegotiation = service_params.pop('sslnegotiation') + val = service_params.pop('sslnegotiation') + if sslnegotiation is None: + sslnegotiation = val if 'sslcrl' in service_params: - sslcrl = service_params.pop('sslcrl') + val = service_params.pop('sslcrl') + if sslcrl is None: + sslcrl = val if 'sslpassword' in service_params: - sslpassword = service_params.pop('sslpassword') + val = service_params.pop('sslpassword') + if sslpassword is None: + sslpassword = val if 'ssl_min_protocol_version' in service_params: - ssl_min_protocol_version = service_params.pop( + val = service_params.pop( 'ssl_min_protocol_version' ) + if ssl_min_protocol_version is None: + ssl_min_protocol_version = val if 'ssl_max_protocol_version' in service_params: - ssl_max_protocol_version = service_params.pop( + val = service_params.pop( 'ssl_max_protocol_version' ) + if ssl_max_protocol_version is None: + ssl_max_protocol_version = val if 'target_session_attrs' in service_params: dsn_target_session_attrs = service_params.pop( From c99ad223418a25b00b5b74d740f234d0c04844c5 Mon Sep 17 00:00:00 2001 From: Andrew Jackson Date: Wed, 26 Mar 2025 13:58:00 -0500 Subject: [PATCH 10/11] fix format --- asyncpg/connect_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 34775af8..3e13eb26 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -487,46 +487,46 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if 'sslcert' in service_params: val = service_params.pop('sslcert') if sslcert is None: - sslcert = val + sslcert = val if 'sslkey' in service_params: val = service_params.pop('sslkey') if sslkey is None: - sslkey = val + sslkey = val if 'sslrootcert' in service_params: val = service_params.pop('sslrootcert') if sslrootcert is None: - sslrootcert = val + sslrootcert = val if 'sslnegotiation' in service_params: val = service_params.pop('sslnegotiation') if sslnegotiation is None: - sslnegotiation = val + sslnegotiation = val if 'sslcrl' in service_params: val = service_params.pop('sslcrl') if sslcrl is None: - sslcrl = val + sslcrl = val if 'sslpassword' in service_params: val = service_params.pop('sslpassword') if sslpassword is None: - sslpassword = val + sslpassword = val if 'ssl_min_protocol_version' in service_params: val = service_params.pop( 'ssl_min_protocol_version' ) if ssl_min_protocol_version is None: - ssl_min_protocol_version = val + ssl_min_protocol_version = val if 'ssl_max_protocol_version' in service_params: val = service_params.pop( 'ssl_max_protocol_version' ) if ssl_max_protocol_version is None: - ssl_max_protocol_version = val + ssl_max_protocol_version = val if 'target_session_attrs' in service_params: dsn_target_session_attrs = service_params.pop( From ae843a2be8fd2ab54764b09c3e40da9dd040f7b1 Mon Sep 17 00:00:00 2001 From: CommanderKeynes Date: Wed, 26 Mar 2025 18:16:07 -0500 Subject: [PATCH 11/11] Add servicefile as parameter --- asyncpg/connect_utils.py | 14 ++++++++++---- asyncpg/connection.py | 9 +++++++++ tests/test_connect.py | 4 +++- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 3e13eb26..f2d6a9fc 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -272,7 +272,8 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, - password, passfile, database, ssl, service, + password, passfile, database, ssl, + service, servicefile, direct_tls, server_settings, target_session_attrs, krbsrvname, gsslib): # `auth_hosts` is the version of host information for the purposes @@ -297,7 +298,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if not service and val: service = val - connection_service_file = os.getenv('PGSERVICEFILE') + connection_service_file = servicefile + + if connection_service_file is None: + connection_service_file = os.getenv('PGSERVICEFILE') + if connection_service_file is None: homedir = compat.get_pg_home_directory() if homedir: @@ -859,7 +864,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cacheable_statement_size, ssl, direct_tls, server_settings, target_session_attrs, krbsrvname, gsslib, - service): + service, servicefile): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -889,7 +894,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, direct_tls=direct_tls, database=database, server_settings=server_settings, target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib, service=service) + krbsrvname=krbsrvname, gsslib=gsslib, + service=service, servicefile=servicefile) config = _ClientConfiguration( command_timeout=command_timeout, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 0ee87861..4e3e5cf1 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2075,6 +2075,7 @@ async def connect(dsn=None, *, host=None, port=None, user=None, password=None, passfile=None, service=None, + servicefile=None, database=None, loop=None, timeout=60, @@ -2188,6 +2189,10 @@ async def connect(dsn=None, *, The name of the postgres connection service stored in the postgres connection service file. + :param servicefile: + The location of the connnection service file used to store + connection parameters. + :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -2400,6 +2405,9 @@ async def connect(dsn=None, *, .. versionchanged:: 0.30.0 Added the *krbsrvname* and *gsslib* parameters. + .. versionchanged:: 0.31.0 + Added the *servicefile* and *service* parameters. + .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext .. _create_default_context: https://docs.python.org/3/library/ssl.html#ssl.create_default_context @@ -2434,6 +2442,7 @@ async def connect(dsn=None, *, password=password, passfile=passfile, service=service, + servicefile=servicefile, ssl=ssl, direct_tls=direct_tls, database=database, diff --git a/tests/test_connect.py b/tests/test_connect.py index 8ffaa46e..ac95e314 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1134,6 +1134,7 @@ def run_testcase(self, testcase): krbsrvname = testcase.get('krbsrvname') gsslib = testcase.get('gsslib') service = testcase.get('service') + servicefile = testcase.get('servicefile') expected = testcase.get('result') expected_error = testcase.get('error') @@ -1159,7 +1160,8 @@ def run_testcase(self, testcase): direct_tls=direct_tls, server_settings=server_settings, target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib, service=service) + krbsrvname=krbsrvname, gsslib=gsslib, + service=service, servicefile=servicefile) params = { k: v for k, v in params._asdict().items()