9
9
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
10
# See the License for the specific language governing permissions and
11
11
# limitations under the License.
12
+ from typing import Any , Optional
13
+
12
14
from sqlalchemy .sql import compiler
13
15
from sqlalchemy .sql .base import DialectKWArgs
14
16
92
94
93
95
94
96
class TrinoSQLCompiler (compiler .SQLCompiler ):
95
- def limit_clause (self , select , ** kw ) :
97
+ def limit_clause (self , select : Any , ** kw : Any ) -> str :
96
98
"""
97
99
Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax.
98
100
"""
@@ -103,15 +105,15 @@ def limit_clause(self, select, **kw):
103
105
text += "\n LIMIT " + self .process (select ._limit_clause , ** kw )
104
106
return text
105
107
106
- def visit_table (self , table , asfrom = False , iscrud = False , ashint = False ,
107
- fromhints = None , use_schema = True , ** kwargs ) :
108
+ def visit_table (self , table : Any , asfrom : bool = False , iscrud : bool = False , ashint : bool = False ,
109
+ fromhints : Optional [ Any ] = None , use_schema : bool = True , ** kwargs : Any ) -> str :
108
110
sql = super (TrinoSQLCompiler , self ).visit_table (
109
111
table , asfrom , iscrud , ashint , fromhints , use_schema , ** kwargs
110
112
)
111
113
return self .add_catalog (sql , table )
112
114
113
115
@staticmethod
114
- def add_catalog (sql , table ) :
116
+ def add_catalog (sql : str , table : Any ) -> str :
115
117
if table is None or not isinstance (table , DialectKWArgs ):
116
118
return sql
117
119
@@ -131,7 +133,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler):
131
133
132
134
133
135
class TrinoTypeCompiler (compiler .GenericTypeCompiler ):
134
- def visit_FLOAT (self , type_ , ** kw ) :
136
+ def visit_FLOAT (self , type_ : Any , ** kw : Any ) -> str :
135
137
precision = type_ .precision or 32
136
138
if 0 <= precision <= 32 :
137
139
return self .visit_REAL (type_ , ** kw )
@@ -140,37 +142,37 @@ def visit_FLOAT(self, type_, **kw):
140
142
else :
141
143
raise ValueError (f"type.precision must be in range [0, 64], got { type_ .precision } " )
142
144
143
- def visit_DOUBLE (self , type_ , ** kw ) :
145
+ def visit_DOUBLE (self , type_ : Any , ** kw : Any ) -> str :
144
146
return "DOUBLE"
145
147
146
- def visit_NUMERIC (self , type_ , ** kw ) :
148
+ def visit_NUMERIC (self , type_ : Any , ** kw : Any ) -> str :
147
149
return self .visit_DECIMAL (type_ , ** kw )
148
150
149
- def visit_NCHAR (self , type_ , ** kw ) :
151
+ def visit_NCHAR (self , type_ : Any , ** kw : Any ) -> str :
150
152
return self .visit_CHAR (type_ , ** kw )
151
153
152
- def visit_NVARCHAR (self , type_ , ** kw ) :
154
+ def visit_NVARCHAR (self , type_ : Any , ** kw : Any ) -> str :
153
155
return self .visit_VARCHAR (type_ , ** kw )
154
156
155
- def visit_TEXT (self , type_ , ** kw ) :
157
+ def visit_TEXT (self , type_ : Any , ** kw : Any ) -> str :
156
158
return self .visit_VARCHAR (type_ , ** kw )
157
159
158
- def visit_BINARY (self , type_ , ** kw ) :
160
+ def visit_BINARY (self , type_ : Any , ** kw : Any ) -> str :
159
161
return self .visit_VARBINARY (type_ , ** kw )
160
162
161
- def visit_CLOB (self , type_ , ** kw ) :
163
+ def visit_CLOB (self , type_ : Any , ** kw : Any ) -> str :
162
164
return self .visit_VARCHAR (type_ , ** kw )
163
165
164
- def visit_NCLOB (self , type_ , ** kw ) :
166
+ def visit_NCLOB (self , type_ : Any , ** kw : Any ) -> str :
165
167
return self .visit_VARCHAR (type_ , ** kw )
166
168
167
- def visit_BLOB (self , type_ , ** kw ) :
169
+ def visit_BLOB (self , type_ : Any , ** kw : Any ) -> str :
168
170
return self .visit_VARBINARY (type_ , ** kw )
169
171
170
- def visit_DATETIME (self , type_ , ** kw ) :
172
+ def visit_DATETIME (self , type_ : Any , ** kw : Any ) -> str :
171
173
return self .visit_TIMESTAMP (type_ , ** kw )
172
174
173
- def visit_TIMESTAMP (self , type_ , ** kw ) :
175
+ def visit_TIMESTAMP (self , type_ : Any , ** kw : Any ) -> str :
174
176
datatype = "TIMESTAMP"
175
177
precision = getattr (type_ , "precision" , None )
176
178
if precision not in range (0 , 13 ) and precision is not None :
@@ -182,7 +184,7 @@ def visit_TIMESTAMP(self, type_, **kw):
182
184
183
185
return datatype
184
186
185
- def visit_TIME (self , type_ , ** kw ) :
187
+ def visit_TIME (self , type_ : Any , ** kw : Any ) -> str :
186
188
datatype = "TIME"
187
189
precision = getattr (type_ , "precision" , None )
188
190
if precision not in range (0 , 13 ) and precision is not None :
@@ -193,13 +195,13 @@ def visit_TIME(self, type_, **kw):
193
195
datatype += " WITH TIME ZONE"
194
196
return datatype
195
197
196
- def visit_JSON (self , type_ , ** kw ) :
198
+ def visit_JSON (self , type_ : Any , ** kw : Any ) -> str :
197
199
return 'JSON'
198
200
199
201
200
202
class TrinoIdentifierPreparer (compiler .IdentifierPreparer ):
201
203
reserved_words = RESERVED_WORDS
202
204
203
- def format_table (self , table , use_schema = True , name = None ):
205
+ def format_table (self , table : Any , use_schema : bool = True , name : Optional [ str ] = None ) -> str :
204
206
result = super (TrinoIdentifierPreparer , self ).format_table (table , use_schema , name )
205
207
return TrinoSQLCompiler .add_catalog (result , table )
0 commit comments