From ca7186a052c948d53f6440e679e8fc044655d779 Mon Sep 17 00:00:00 2001
From: ansipunk
Date: Sun, 3 Mar 2024 15:33:39 +0500
Subject: [PATCH] S01E09

---
 databases/backends/aiopg.py              |  6 +-
 databases/backends/compilers/__init__.py |  0
 databases/backends/compilers/psycopg.py  | 17 -----
 databases/backends/dialects/__init__.py  |  0
 databases/backends/dialects/psycopg.py   | 46 -------------
 databases/backends/psycopg.py            | 15 +++--
 tests/test_databases.py                  | 86 +++++++++----------
 7 files changed, 41 insertions(+), 129 deletions(-)
 delete mode 100644 databases/backends/compilers/__init__.py
 delete mode 100644 databases/backends/compilers/psycopg.py
 delete mode 100644 databases/backends/dialects/__init__.py
 delete mode 100644 databases/backends/dialects/psycopg.py

diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py
index 0b4d95a3..1df30699 100644
--- a/databases/backends/aiopg.py
+++ b/databases/backends/aiopg.py
@@ -5,15 +5,13 @@
 import uuid
 
 import aiopg
+from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
 from sqlalchemy.engine.cursor import CursorResultMetaData
 from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
-from sqlalchemy.engine.row import Row
 from sqlalchemy.sql import ClauseElement
 from sqlalchemy.sql.ddl import DDLElement
 
 from databases.backends.common.records import Record, Row, create_column_maps
-from databases.backends.compilers.psycopg import PGCompiler_psycopg
-from databases.backends.dialects.psycopg import PGDialect_psycopg
 from databases.core import LOG_EXTRA, DatabaseURL
 from databases.interfaces import (
     ConnectionBackend,
@@ -38,12 +36,10 @@ def _get_dialect(self) -> Dialect:
         dialect = PGDialect_psycopg(
             json_serializer=json.dumps, json_deserializer=lambda x: x
         )
-        dialect.statement_compiler = PGCompiler_psycopg
         dialect.implicit_returning = True
         dialect.supports_native_enum = True
         dialect.supports_smallserial = True  # 9.2+
         dialect._backslash_escapes = False
-        dialect.supports_sane_multi_rowcount = True  # psycopg 2.0.9+
         dialect._has_native_hstore = True
         dialect.supports_native_decimal = True
 
diff --git a/databases/backends/compilers/__init__.py b/databases/backends/compilers/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/databases/backends/compilers/psycopg.py b/databases/backends/compilers/psycopg.py
deleted file mode 100644
index 654c22a1..00000000
--- a/databases/backends/compilers/psycopg.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from sqlalchemy.dialects.postgresql.psycopg import PGCompiler_psycopg
-
-
-class APGCompiler_psycopg2(PGCompiler_psycopg):
-    def construct_params(self, *args, **kwargs):
-        pd = super().construct_params(*args, **kwargs)
-
-        for column in self.prefetch:
-            pd[column.key] = self._exec_default(column.default)
-
-        return pd
-
-    def _exec_default(self, default):
-        if default.is_callable:
-            return default.arg(self.dialect)
-        else:
-            return default.arg
diff --git a/databases/backends/dialects/__init__.py b/databases/backends/dialects/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/databases/backends/dialects/psycopg.py b/databases/backends/dialects/psycopg.py
deleted file mode 100644
index 07bd1880..00000000
--- a/databases/backends/dialects/psycopg.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""
-All the unique changes for the databases package
-with the custom Numeric as the deprecated pypostgresql
-for backwards compatibility and to make sure the
-package can go to SQLAlchemy 2.0+.
-""" - -import typing - -from sqlalchemy import types, util -from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext -from sqlalchemy.engine import processors -from sqlalchemy.types import Float, Numeric - - -class PGExecutionContext_psycopg(PGExecutionContext): - ... - - -class PGNumeric(Numeric): - def bind_processor( - self, dialect: typing.Any - ) -> typing.Union[str, None]: # pragma: no cover - return processors.to_str - - def result_processor( - self, dialect: typing.Any, coltype: typing.Any - ) -> typing.Union[float, None]: # pragma: no cover - if self.asdecimal: - return None - else: - return processors.to_float - - -class PGDialect_psycopg(PGDialect): - colspecs = util.update_copy( - PGDialect.colspecs, - { - types.Numeric: PGNumeric, - types.Float: Float, - }, - ) - execution_ctx_cls = PGExecutionContext_psycopg - - -dialect = PGDialect_psycopg diff --git a/databases/backends/psycopg.py b/databases/backends/psycopg.py index f83f4917..da0a6718 100644 --- a/databases/backends/psycopg.py +++ b/databases/backends/psycopg.py @@ -22,7 +22,7 @@ class PsycopgBackend(DatabaseBackend): _database_url: DatabaseURL _options: typing.Dict[str, typing.Any] _dialect: Dialect - _pool: typing.Optional[psycopg_pool.AsyncConnectionPool] + _pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None def __init__( self, @@ -33,7 +33,6 @@ def __init__( self._options = options self._dialect = PGDialect_psycopg() self._dialect.implicit_returning = True - self._pool = None async def connect(self) -> None: if self._pool is not None: @@ -95,7 +94,10 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: rows = await cursor.fetchall() column_maps = create_column_maps(result_columns) - return [PsycopgRecord(row, result_columns, self._dialect, column_maps) for row in rows] + return [ + PsycopgRecord(row, result_columns, self._dialect, column_maps) + for row in rows + ] async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: if self._connection is None: @@ -167,7 +169,8 @@ def raw_connection(self) -> typing.Any: return self._connection def _compile( - self, query: ClauseElement, + self, + query: ClauseElement, ) -> typing.Tuple[str, typing.Mapping[str, typing.Any], tuple]: compiled = query.compile( dialect=self._dialect, @@ -224,7 +227,9 @@ def _mapping(self) -> typing.Mapping: def __getitem__(self, key: typing.Any) -> typing.Any: if len(self._column_map) == 0: - return self._mapping[key] + if isinstance(key, str): + return self._mapping[key] + return self._row[key] elif isinstance(key, Column): idx, datatype = self._column_map_full[str(key)] elif isinstance(key, int): diff --git a/tests/test_databases.py b/tests/test_databases.py index 5c0b61d1..66164aea 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -204,17 +204,17 @@ async def test_queries(database_url): assert len(results) == 3 assert results[0]["text"] == "example1" - assert results[0]["completed"] == True + assert results[0]["completed"] is True assert results[1]["text"] == "example2" - assert results[1]["completed"] == False + assert results[1]["completed"] is False assert results[2]["text"] == "example3" - assert results[2]["completed"] == True + assert results[2]["completed"] is True # fetch_one() query = notes.select() result = await database.fetch_one(query=query) assert result["text"] == "example1" - assert result["completed"] == True + assert result["completed"] is True # fetch_val() query = sqlalchemy.sql.select(*[notes.c.text]) @@ -246,11 
                 iterate_results.append(result)
             assert len(iterate_results) == 3
             assert iterate_results[0]["text"] == "example1"
-            assert iterate_results[0]["completed"] == True
+            assert iterate_results[0]["completed"] is True
             assert iterate_results[1]["text"] == "example2"
-            assert iterate_results[1]["completed"] == False
+            assert iterate_results[1]["completed"] is False
             assert iterate_results[2]["text"] == "example3"
-            assert iterate_results[2]["completed"] == True
+            assert iterate_results[2]["completed"] is True
 
 
 @pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -280,26 +280,26 @@ async def test_queries_raw(database_url):
             results = await database.fetch_all(query=query, values={"completed": True})
             assert len(results) == 2
             assert results[0]["text"] == "example1"
-            assert results[0]["completed"] == True
+            assert results[0]["completed"] is True
             assert results[1]["text"] == "example3"
-            assert results[1]["completed"] == True
+            assert results[1]["completed"] is True
 
             # fetch_one()
             query = "SELECT * FROM notes WHERE completed = :completed"
             result = await database.fetch_one(query=query, values={"completed": False})
             assert result["text"] == "example2"
-            assert result["completed"] == False
+            assert result["completed"] is False
 
             # fetch_val()
             query = "SELECT completed FROM notes WHERE text = :text"
             result = await database.fetch_val(query=query, values={"text": "example1"})
-            assert result == True
+            assert result is True
 
             query = "SELECT * FROM notes WHERE text = :text"
             result = await database.fetch_val(
                 query=query, values={"text": "example1"}, column="completed"
             )
-            assert result == True
+            assert result is True
 
             # iterate()
             query = "SELECT * FROM notes"
@@ -308,11 +308,11 @@ async def test_queries_raw(database_url):
                 iterate_results.append(result)
             assert len(iterate_results) == 3
             assert iterate_results[0]["text"] == "example1"
-            assert iterate_results[0]["completed"] == True
+            assert iterate_results[0]["completed"] is True
             assert iterate_results[1]["text"] == "example2"
-            assert iterate_results[1]["completed"] == False
+            assert iterate_results[1]["completed"] is False
             assert iterate_results[2]["text"] == "example3"
-            assert iterate_results[2]["completed"] == True
+            assert iterate_results[2]["completed"] is True
 
 
 @pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -380,7 +380,7 @@ async def test_results_support_mapping_interface(database_url):
 
             assert isinstance(results_as_dicts[0]["id"], int)
             assert results_as_dicts[0]["text"] == "example1"
-            assert results_as_dicts[0]["completed"] == True
+            assert results_as_dicts[0]["completed"] is True
 
 
 @pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -467,7 +467,7 @@ async def test_execute_return_val(database_url):
             query = notes.select().where(notes.c.id == pk)
             result = await database.fetch_one(query)
             assert result["text"] == "example1"
-            assert result["completed"] == True
+            assert result["completed"] is True
 
 
 @pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -857,7 +857,7 @@ async def test_transaction_commit_low_level(database_url):
             try:
                 query = notes.insert().values(text="example1", completed=True)
                 await database.execute(query)
-            except:  # pragma: no cover
+            except Exception:  # pragma: no cover
                 await transaction.rollback()
             else:
                 await transaction.commit()
@@ -881,7 +881,7 @@ async def test_transaction_rollback_low_level(database_url):
                 query = notes.insert().values(text="example1", completed=True)
                 await database.execute(query)
                 raise RuntimeError()
-            except:
+            except Exception:
                 await transaction.rollback()
             else:  # pragma: no cover
                 await transaction.commit()
@@ -1354,13 +1354,12 @@ async def test_queries_with_expose_backend_connection(database_url):
                 ]:
                     cursor = await raw_connection.cursor()
                     await cursor.execute(insert_query, values)
-                elif database.url.scheme == "mysql+asyncmy":
+                elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
                     async with raw_connection.cursor() as cursor:
                         await cursor.execute(insert_query, values)
                 elif database.url.scheme in [
                     "postgresql",
                     "postgresql+asyncpg",
-                    "postgresql+psycopg",
                 ]:
                     await raw_connection.execute(insert_query, *values)
                 elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
@@ -1372,7 +1371,7 @@ async def test_queries_with_expose_backend_connection(database_url):
                 if database.url.scheme in ["mysql", "mysql+aiomysql"]:
                     cursor = await raw_connection.cursor()
                     await cursor.executemany(insert_query, values)
-                elif database.url.scheme == "mysql+asyncmy":
+                elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
                     async with raw_connection.cursor() as cursor:
                         await cursor.executemany(insert_query, values)
                 elif database.url.scheme == "postgresql+aiopg":
@@ -1395,15 +1394,11 @@ async def test_queries_with_expose_backend_connection(database_url):
                     cursor = await raw_connection.cursor()
                     await cursor.execute(select_query)
                     results = await cursor.fetchall()
-                elif database.url.scheme == "mysql+asyncmy":
+                elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
                     async with raw_connection.cursor() as cursor:
                         await cursor.execute(select_query)
                         results = await cursor.fetchall()
-                elif database.url.scheme in [
-                    "postgresql",
-                    "postgresql+asyncpg",
-                    "postgresql+psycopg",
-                ]:
+                elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
                     results = await raw_connection.fetch(select_query)
                 elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
                     results = await raw_connection.execute_fetchall(select_query)
@@ -1411,20 +1406,16 @@ async def test_queries_with_expose_backend_connection(database_url):
                 assert len(results) == 3
                 # Raw output for the raw request
                 assert results[0][1] == "example1"
-                assert results[0][2] == True
+                assert results[0][2] is True
                 assert results[1][1] == "example2"
-                assert results[1][2] == False
+                assert results[1][2] is False
                 assert results[2][1] == "example3"
-                assert results[2][2] == True
+                assert results[2][2] is True
 
                 # fetch_one()
-                if database.url.scheme in [
-                    "postgresql",
-                    "postgresql+asyncpg",
-                    "postgresql+psycopg",
-                ]:
+                if database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
                     result = await raw_connection.fetchrow(select_query)
-                elif database.url.scheme == "mysql+asyncmy":
+                elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
                     async with raw_connection.cursor() as cursor:
                         await cursor.execute(select_query)
                         result = await cursor.fetchone()
@@ -1435,7 +1426,7 @@ async def test_queries_with_expose_backend_connection(database_url):
 
                 # Raw output for the raw request
                 assert result[1] == "example1"
-                assert result[2] == True
+                assert result[2] is True
 
 
 @pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -1606,7 +1597,7 @@ async def test_column_names(database_url, select_query):
 
             assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"]
             assert results[0]["text"] == "example1"
-            assert results[0]["completed"] == True
+            assert results[0]["completed"] is True
 
 
 @pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -1641,23 +1632,6 @@ async def test_result_named_access(database_url):
         assert result.completed is True
 
 
-@pytest.mark.parametrize("database_url", DATABASE_URLS)
-@async_adapter
-async def test_mapping_property_interface(database_url):
-    """
-    Test that all connections implement interface with `_mapping` property
-    """
-    async with Database(database_url) as database:
-        query = notes.select()
-        single_result = await database.fetch_one(query=query)
-        assert single_result._mapping["text"] == "example1"
-        assert single_result._mapping["completed"] is True
-
-        list_result = await database.fetch_all(query=query)
-        assert list_result[0]._mapping["text"] == "example1"
-        assert list_result[0]._mapping["completed"] is True
-
-
 @async_adapter
 async def test_should_not_maintain_ref_when_no_cache_param():
     async with Database(