Skip to content

Commit ce60a48

Browse files
mdesmethashhar
authored andcommitted
Fix 'Engine' object has no attribute 'connection'
1 parent 196188e commit ce60a48

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

Diff for: tests/integration/test_sqlalchemy_integration.py

+24
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,27 @@ def test_json_column(trino_connection, json_object):
344344
assert rows[0] == (1, json_object)
345345
finally:
346346
metadata.drop_all(engine)
347+
348+
349+
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
350+
def test_get_table_comment(trino_connection):
351+
engine, conn = trino_connection
352+
353+
if not engine.dialect.has_schema(engine, "test"):
354+
engine.execute(sqla.schema.CreateSchema("test"))
355+
metadata = sqla.MetaData()
356+
357+
try:
358+
sqla.Table(
359+
'table_with_id',
360+
metadata,
361+
sqla.Column('id', sqla.Integer),
362+
schema="test",
363+
# comment="This is a comment" TODO: Support comment creation through sqlalchemy api
364+
)
365+
metadata.create_all(engine)
366+
insp = sqla.inspect(engine)
367+
actual = insp.get_table_comment(table_name='table_with_id', schema="test")
368+
assert actual['text'] is None
369+
finally:
370+
metadata.drop_all(engine)

Diff for: trino/sqlalchemy/dialect.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
# limitations under the License.
1212
import json
1313
from textwrap import dedent
14-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
14+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
1515
from urllib.parse import unquote_plus
1616

1717
from sqlalchemy import exc, sql
18+
from sqlalchemy.engine import Engine
1819
from sqlalchemy.engine.base import Connection
1920
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
2021
from sqlalchemy.engine.url import URL
@@ -340,12 +341,17 @@ def _get_server_version_info(self, connection: Connection) -> Any:
340341
logger.debug(f"Failed to get server version: {e.orig.message}")
341342
return None
342343

344+
def _raw_connection(self, connection: Union[Engine, Connection]) -> trino_dbapi.Connection:
345+
if isinstance(connection, Engine):
346+
return connection.raw_connection()
347+
return connection.connection
348+
343349
def _get_default_catalog_name(self, connection: Connection) -> Optional[str]:
344-
dbapi_connection: trino_dbapi.Connection = connection.connection
350+
dbapi_connection: trino_dbapi.Connection = self._raw_connection(connection)
345351
return dbapi_connection.catalog
346352

347353
def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
348-
dbapi_connection: trino_dbapi.Connection = connection.connection
354+
dbapi_connection: trino_dbapi.Connection = self._raw_connection(connection)
349355
return dbapi_connection.schema
350356

351357
def do_execute(

0 commit comments

Comments
 (0)