|
11 | 11 | # limitations under the License.
|
12 | 12 | import json
|
13 | 13 | from textwrap import dedent
|
14 |
| -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple |
| 14 | +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union |
15 | 15 | from urllib.parse import unquote_plus
|
16 | 16 |
|
17 | 17 | from sqlalchemy import exc, sql
|
| 18 | +from sqlalchemy.engine import Engine |
18 | 19 | from sqlalchemy.engine.base import Connection
|
19 | 20 | from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
|
20 | 21 | from sqlalchemy.engine.url import URL
|
@@ -340,12 +341,17 @@ def _get_server_version_info(self, connection: Connection) -> Any:
|
340 | 341 | logger.debug(f"Failed to get server version: {e.orig.message}")
|
341 | 342 | return None
|
342 | 343 |
|
| 344 | + def _raw_connection(self, connection: Union[Engine, Connection]) -> trino_dbapi.Connection: |
| 345 | + if isinstance(connection, Engine): |
| 346 | + return connection.raw_connection() |
| 347 | + return connection.connection |
| 348 | + |
343 | 349 | def _get_default_catalog_name(self, connection: Connection) -> Optional[str]:
|
344 |
| - dbapi_connection: trino_dbapi.Connection = connection.connection |
| 350 | + dbapi_connection: trino_dbapi.Connection = self._raw_connection(connection) |
345 | 351 | return dbapi_connection.catalog
|
346 | 352 |
|
347 | 353 | def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
|
348 |
| - dbapi_connection: trino_dbapi.Connection = connection.connection |
| 354 | + dbapi_connection: trino_dbapi.Connection = self._raw_connection(connection) |
349 | 355 | return dbapi_connection.schema
|
350 | 356 |
|
351 | 357 | def do_execute(
|
|
0 commit comments