Skip to content

Commit 87e70bb

Browse files
committed
Add type hints for sqlalchemy
1 parent ea55fc0 commit 87e70bb

File tree

4 files changed

+191
-71
lines changed

4 files changed

+191
-71
lines changed

trino/sqlalchemy/compiler.py

+21-19
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12+
from typing import Any, Optional
13+
1214
from sqlalchemy.sql import compiler
1315
from sqlalchemy.sql.base import DialectKWArgs
1416

@@ -92,7 +94,7 @@
9294

9395

9496
class TrinoSQLCompiler(compiler.SQLCompiler):
95-
def limit_clause(self, select, **kw):
97+
def limit_clause(self, select: Any, **kw: Any) -> str:
9698
"""
9799
Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax.
98100
"""
@@ -103,15 +105,15 @@ def limit_clause(self, select, **kw):
103105
text += "\nLIMIT " + self.process(select._limit_clause, **kw)
104106
return text
105107

106-
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
107-
fromhints=None, use_schema=True, **kwargs):
108+
def visit_table(self, table: Any, asfrom: bool = False, iscrud: bool = False, ashint: bool = False,
109+
fromhints: Optional[Any] = None, use_schema: bool = True, **kwargs: Any) -> str:
108110
sql = super(TrinoSQLCompiler, self).visit_table(
109111
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
110112
)
111113
return self.add_catalog(sql, table)
112114

113115
@staticmethod
114-
def add_catalog(sql, table):
116+
def add_catalog(sql: str, table: Any) -> str:
115117
if table is None or not isinstance(table, DialectKWArgs):
116118
return sql
117119

@@ -131,7 +133,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler):
131133

132134

133135
class TrinoTypeCompiler(compiler.GenericTypeCompiler):
134-
def visit_FLOAT(self, type_, **kw):
136+
def visit_FLOAT(self, type_: Any, **kw: Any) -> str:
135137
precision = type_.precision or 32
136138
if 0 <= precision <= 32:
137139
return self.visit_REAL(type_, **kw)
@@ -140,37 +142,37 @@ def visit_FLOAT(self, type_, **kw):
140142
else:
141143
raise ValueError(f"type.precision must be in range [0, 64], got {type_.precision}")
142144

143-
def visit_DOUBLE(self, type_, **kw):
145+
def visit_DOUBLE(self, type_: Any, **kw: Any) -> str:
144146
return "DOUBLE"
145147

146-
def visit_NUMERIC(self, type_, **kw):
148+
def visit_NUMERIC(self, type_: Any, **kw: Any) -> str:
147149
return self.visit_DECIMAL(type_, **kw)
148150

149-
def visit_NCHAR(self, type_, **kw):
151+
def visit_NCHAR(self, type_: Any, **kw: Any) -> str:
150152
return self.visit_CHAR(type_, **kw)
151153

152-
def visit_NVARCHAR(self, type_, **kw):
154+
def visit_NVARCHAR(self, type_: Any, **kw: Any) -> str:
153155
return self.visit_VARCHAR(type_, **kw)
154156

155-
def visit_TEXT(self, type_, **kw):
157+
def visit_TEXT(self, type_: Any, **kw: Any) -> str:
156158
return self.visit_VARCHAR(type_, **kw)
157159

158-
def visit_BINARY(self, type_, **kw):
160+
def visit_BINARY(self, type_: Any, **kw: Any) -> str:
159161
return self.visit_VARBINARY(type_, **kw)
160162

161-
def visit_CLOB(self, type_, **kw):
163+
def visit_CLOB(self, type_: Any, **kw: Any) -> str:
162164
return self.visit_VARCHAR(type_, **kw)
163165

164-
def visit_NCLOB(self, type_, **kw):
166+
def visit_NCLOB(self, type_: Any, **kw: Any) -> str:
165167
return self.visit_VARCHAR(type_, **kw)
166168

167-
def visit_BLOB(self, type_, **kw):
169+
def visit_BLOB(self, type_: Any, **kw: Any) -> str:
168170
return self.visit_VARBINARY(type_, **kw)
169171

170-
def visit_DATETIME(self, type_, **kw):
172+
def visit_DATETIME(self, type_: Any, **kw: Any) -> str:
171173
return self.visit_TIMESTAMP(type_, **kw)
172174

173-
def visit_TIMESTAMP(self, type_, **kw):
175+
def visit_TIMESTAMP(self, type_: Any, **kw: Any) -> str:
174176
datatype = "TIMESTAMP"
175177
precision = getattr(type_, "precision", None)
176178
if precision not in range(0, 13) and precision is not None:
@@ -182,7 +184,7 @@ def visit_TIMESTAMP(self, type_, **kw):
182184

183185
return datatype
184186

185-
def visit_TIME(self, type_, **kw):
187+
def visit_TIME(self, type_: Any, **kw: Any) -> str:
186188
datatype = "TIME"
187189
precision = getattr(type_, "precision", None)
188190
if precision not in range(0, 13) and precision is not None:
@@ -193,13 +195,13 @@ def visit_TIME(self, type_, **kw):
193195
datatype += " WITH TIME ZONE"
194196
return datatype
195197

196-
def visit_JSON(self, type_, **kw):
198+
def visit_JSON(self, type_: Any, **kw: Any) -> str:
197199
return 'JSON'
198200

199201

200202
class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
201203
reserved_words = RESERVED_WORDS
202204

203-
def format_table(self, table, use_schema=True, name=None):
205+
def format_table(self, table: Any, use_schema: bool = True, name: Optional[str] = None) -> str:
204206
result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name)
205207
return TrinoSQLCompiler.add_catalog(result, table)

trino/sqlalchemy/datatype.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,18 @@
1111
# limitations under the License.
1212
import json
1313
import re
14-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
14+
from typing import Any, Dict, Iterator, List, Optional
15+
from typing import Text as typing_Text
16+
from typing import Tuple, Type, TypeVar, Union
1517

1618
from sqlalchemy import util
19+
from sqlalchemy.engine.interfaces import Dialect
1720
from sqlalchemy.sql import sqltypes
1821
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
1922
from sqlalchemy.types import String
2023

2124
SQLType = Union[TypeEngine, Type[TypeEngine]]
25+
_T = TypeVar('_T')
2226

2327

2428
class DOUBLE(sqltypes.Float):
@@ -38,7 +42,7 @@ def __init__(self, key_type: SQLType, value_type: SQLType):
3842
self.value_type: TypeEngine = value_type
3943

4044
@property
41-
def python_type(self):
45+
def python_type(self) -> type:
4246
return dict
4347

4448

@@ -53,36 +57,36 @@ def __init__(self, attr_types: List[Tuple[Optional[str], SQLType]]):
5357
self.attr_types.append((attr_name, attr_type))
5458

5559
@property
56-
def python_type(self):
60+
def python_type(self) -> type:
5761
return list
5862

5963

6064
class TIME(sqltypes.TIME):
6165
__visit_name__ = "TIME"
6266

63-
def __init__(self, precision=None, timezone=False):
67+
def __init__(self, precision: Optional[int] = None, timezone: bool = False):
6468
super(TIME, self).__init__(timezone=timezone)
6569
self.precision = precision
6670

6771

6872
class TIMESTAMP(sqltypes.TIMESTAMP):
6973
__visit_name__ = "TIMESTAMP"
7074

71-
def __init__(self, precision=None, timezone=False):
75+
def __init__(self, precision: Optional[int] = None, timezone: bool = False):
7276
super(TIMESTAMP, self).__init__(timezone=timezone)
7377
self.precision = precision
7478

7579

7680
class JSON(TypeDecorator):
7781
impl = String
7882

79-
def process_bind_param(self, value, dialect):
83+
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]:
8084
return json.dumps(value)
8185

82-
def process_result_value(self, value, dialect):
86+
def process_result_value(self, value: Union[str, bytes], dialect: Dialect) -> Optional[_T]:
8387
return json.loads(value)
8488

85-
def get_col_spec(self, **kw):
89+
def get_col_spec(self, **kw: Any) -> str:
8690
return 'JSON'
8791

8892

0 commit comments

Comments
 (0)