9
9
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
10
# See the License for the specific language governing permissions and
11
11
# limitations under the License.
12
+ from typing import Any , Dict , Optional
13
+
12
14
from sqlalchemy .sql import compiler
13
15
from sqlalchemy .sql .base import DialectKWArgs
16
+ from sqlalchemy .sql .schema import Table
14
17
15
18
# https://trino.io/docs/current/language/reserved.html
16
19
RESERVED_WORDS = {
92
95
93
96
94
97
class TrinoSQLCompiler (compiler .SQLCompiler ):
95
- def limit_clause (self , select , ** kw ) :
98
+ def limit_clause (self , select : Any , ** kw : Dict [ str , Any ]) -> str :
96
99
"""
97
100
Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax.
98
101
"""
@@ -103,15 +106,15 @@ def limit_clause(self, select, **kw):
103
106
text += "\n LIMIT " + self .process (select ._limit_clause , ** kw )
104
107
return text
105
108
106
- def visit_table (self , table , asfrom = False , iscrud = False , ashint = False ,
107
- fromhints = None , use_schema = True , ** kwargs ) :
109
+ def visit_table (self , table : Table , asfrom : bool = False , iscrud : bool = False , ashint : bool = False ,
110
+ fromhints : Optional [ Any ] = None , use_schema : bool = True , ** kwargs : Any ) -> str :
108
111
sql = super (TrinoSQLCompiler , self ).visit_table (
109
112
table , asfrom , iscrud , ashint , fromhints , use_schema , ** kwargs
110
113
)
111
114
return self .add_catalog (sql , table )
112
115
113
116
@staticmethod
114
- def add_catalog (sql , table ) :
117
+ def add_catalog (sql : str , table : Table ) -> str :
115
118
if table is None or not isinstance (table , DialectKWArgs ):
116
119
return sql
117
120
@@ -131,7 +134,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler):
131
134
132
135
133
136
class TrinoTypeCompiler (compiler .GenericTypeCompiler ):
134
- def visit_FLOAT (self , type_ , ** kw ) :
137
+ def visit_FLOAT (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
135
138
precision = type_ .precision or 32
136
139
if 0 <= precision <= 32 :
137
140
return self .visit_REAL (type_ , ** kw )
@@ -140,37 +143,37 @@ def visit_FLOAT(self, type_, **kw):
140
143
else :
141
144
raise ValueError (f"type.precision must be in range [0, 64], got { type_ .precision } " )
142
145
143
- def visit_DOUBLE (self , type_ , ** kw ) :
146
+ def visit_DOUBLE (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
144
147
return "DOUBLE"
145
148
146
- def visit_NUMERIC (self , type_ , ** kw ) :
149
+ def visit_NUMERIC (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
147
150
return self .visit_DECIMAL (type_ , ** kw )
148
151
149
- def visit_NCHAR (self , type_ , ** kw ) :
152
+ def visit_NCHAR (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
150
153
return self .visit_CHAR (type_ , ** kw )
151
154
152
- def visit_NVARCHAR (self , type_ , ** kw ) :
155
+ def visit_NVARCHAR (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
153
156
return self .visit_VARCHAR (type_ , ** kw )
154
157
155
- def visit_TEXT (self , type_ , ** kw ) :
158
+ def visit_TEXT (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
156
159
return self .visit_VARCHAR (type_ , ** kw )
157
160
158
- def visit_BINARY (self , type_ , ** kw ) :
161
+ def visit_BINARY (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
159
162
return self .visit_VARBINARY (type_ , ** kw )
160
163
161
- def visit_CLOB (self , type_ , ** kw ) :
164
+ def visit_CLOB (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
162
165
return self .visit_VARCHAR (type_ , ** kw )
163
166
164
- def visit_NCLOB (self , type_ , ** kw ) :
167
+ def visit_NCLOB (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
165
168
return self .visit_VARCHAR (type_ , ** kw )
166
169
167
- def visit_BLOB (self , type_ , ** kw ) :
170
+ def visit_BLOB (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
168
171
return self .visit_VARBINARY (type_ , ** kw )
169
172
170
- def visit_DATETIME (self , type_ , ** kw ) :
173
+ def visit_DATETIME (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
171
174
return self .visit_TIMESTAMP (type_ , ** kw )
172
175
173
- def visit_TIMESTAMP (self , type_ , ** kw ) :
176
+ def visit_TIMESTAMP (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
174
177
datatype = "TIMESTAMP"
175
178
precision = getattr (type_ , "precision" , None )
176
179
if precision not in range (0 , 13 ) and precision is not None :
@@ -182,7 +185,7 @@ def visit_TIMESTAMP(self, type_, **kw):
182
185
183
186
return datatype
184
187
185
- def visit_TIME (self , type_ , ** kw ) :
188
+ def visit_TIME (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
186
189
datatype = "TIME"
187
190
precision = getattr (type_ , "precision" , None )
188
191
if precision not in range (0 , 13 ) and precision is not None :
@@ -193,13 +196,13 @@ def visit_TIME(self, type_, **kw):
193
196
datatype += " WITH TIME ZONE"
194
197
return datatype
195
198
196
- def visit_JSON (self , type_ , ** kw ) :
199
+ def visit_JSON (self , type_ : Any , ** kw : Dict [ str , Any ]) -> str :
197
200
return 'JSON'
198
201
199
202
200
203
class TrinoIdentifierPreparer (compiler .IdentifierPreparer ):
201
204
reserved_words = RESERVED_WORDS
202
205
203
- def format_table (self , table , use_schema = True , name = None ):
206
+ def format_table (self , table : Table , use_schema : bool = True , name : Optional [ str ] = None ) -> str :
204
207
result = super (TrinoIdentifierPreparer , self ).format_table (table , use_schema , name )
205
208
return TrinoSQLCompiler .add_catalog (result , table )
0 commit comments