diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index ad28b18a..543fd2c4 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -218,7 +218,7 @@ def _get_partitions( connection: Connection, table_name: str, schema: str = None - ) -> List[Dict[str, List[Any]]]: + ) -> Optional[List[str]]: schema = schema or self._get_default_schema_name(connection) query = dedent( f""" @@ -227,6 +227,17 @@ def _get_partitions( ).strip() res = connection.execute(sql.text(query)) partition_names = [desc[0] for desc in res.cursor.description] + data_types = [desc[1] for desc in res.cursor.description] + # Compare the column names and types to the shape of an Iceberg $partitions table + if (partition_names == ['partition', 'record_count', 'file_count', 'total_size', 'data'] + and data_types[0].startswith('row(') + and data_types[1] == 'bigint' + and data_types[2] == 'bigint' + and data_types[3] == 'bigint' + and data_types[4].startswith('row(')): + # This is an Iceberg $partitions table - these match the partition metadata columns + return None + # This is a Hive table - these are the partition names return partition_names def get_pk_constraint(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]: @@ -326,7 +337,7 @@ def get_indexes(self, connection: Connection, table_name: str, schema: str = Non try: partitioned_columns = self._get_partitions(connection, f"{table_name}", schema) except Exception as e: - # e.g. it's not a Hive table or an unpartitioned Hive table + # e.g. it's an unpartitioned Hive table logger.debug("Couldn't fetch partition columns. schema: %s, table: %s, error: %s", schema, table_name, e) if not partitioned_columns: return []