Skip to content

Commit 76628b3

Browse files
committed
Add type hints for sqlalchemy
1 parent ea55fc0 commit 76628b3

File tree

4 files changed

+192
-71
lines changed

4 files changed

+192
-71
lines changed

trino/sqlalchemy/compiler.py

+22-19
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12+
from typing import Any, Dict, Optional
13+
1214
from sqlalchemy.sql import compiler
1315
from sqlalchemy.sql.base import DialectKWArgs
16+
from sqlalchemy.sql.schema import Table
1417

1518
# https://trino.io/docs/current/language/reserved.html
1619
RESERVED_WORDS = {
@@ -92,7 +95,7 @@
9295

9396

9497
class TrinoSQLCompiler(compiler.SQLCompiler):
95-
def limit_clause(self, select, **kw):
98+
def limit_clause(self, select: Any, **kw: Dict[str, Any]) -> str:
9699
"""
97100
Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax.
98101
"""
@@ -103,15 +106,15 @@ def limit_clause(self, select, **kw):
103106
text += "\nLIMIT " + self.process(select._limit_clause, **kw)
104107
return text
105108

106-
def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
107-
fromhints=None, use_schema=True, **kwargs):
109+
def visit_table(self, table: Table, asfrom: bool = False, iscrud: bool = False, ashint: bool = False,
110+
fromhints: Optional[Any] = None, use_schema: bool = True, **kwargs: Any) -> str:
108111
sql = super(TrinoSQLCompiler, self).visit_table(
109112
table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs
110113
)
111114
return self.add_catalog(sql, table)
112115

113116
@staticmethod
114-
def add_catalog(sql, table):
117+
def add_catalog(sql: str, table: Table) -> str:
115118
if table is None or not isinstance(table, DialectKWArgs):
116119
return sql
117120

@@ -131,7 +134,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler):
131134

132135

133136
class TrinoTypeCompiler(compiler.GenericTypeCompiler):
134-
def visit_FLOAT(self, type_, **kw):
137+
def visit_FLOAT(self, type_: Any, **kw: Dict[str, Any]) -> str:
135138
precision = type_.precision or 32
136139
if 0 <= precision <= 32:
137140
return self.visit_REAL(type_, **kw)
@@ -140,37 +143,37 @@ def visit_FLOAT(self, type_, **kw):
140143
else:
141144
raise ValueError(f"type.precision must be in range [0, 64], got {type_.precision}")
142145

143-
def visit_DOUBLE(self, type_, **kw):
146+
def visit_DOUBLE(self, type_: Any, **kw: Dict[str, Any]) -> str:
144147
return "DOUBLE"
145148

146-
def visit_NUMERIC(self, type_, **kw):
149+
def visit_NUMERIC(self, type_: Any, **kw: Dict[str, Any]) -> str:
147150
return self.visit_DECIMAL(type_, **kw)
148151

149-
def visit_NCHAR(self, type_, **kw):
152+
def visit_NCHAR(self, type_: Any, **kw: Dict[str, Any]) -> str:
150153
return self.visit_CHAR(type_, **kw)
151154

152-
def visit_NVARCHAR(self, type_, **kw):
155+
def visit_NVARCHAR(self, type_: Any, **kw: Dict[str, Any]) -> str:
153156
return self.visit_VARCHAR(type_, **kw)
154157

155-
def visit_TEXT(self, type_, **kw):
158+
def visit_TEXT(self, type_: Any, **kw: Dict[str, Any]) -> str:
156159
return self.visit_VARCHAR(type_, **kw)
157160

158-
def visit_BINARY(self, type_, **kw):
161+
def visit_BINARY(self, type_: Any, **kw: Dict[str, Any]) -> str:
159162
return self.visit_VARBINARY(type_, **kw)
160163

161-
def visit_CLOB(self, type_, **kw):
164+
def visit_CLOB(self, type_: Any, **kw: Dict[str, Any]) -> str:
162165
return self.visit_VARCHAR(type_, **kw)
163166

164-
def visit_NCLOB(self, type_, **kw):
167+
def visit_NCLOB(self, type_: Any, **kw: Dict[str, Any]) -> str:
165168
return self.visit_VARCHAR(type_, **kw)
166169

167-
def visit_BLOB(self, type_, **kw):
170+
def visit_BLOB(self, type_: Any, **kw: Dict[str, Any]) -> str:
168171
return self.visit_VARBINARY(type_, **kw)
169172

170-
def visit_DATETIME(self, type_, **kw):
173+
def visit_DATETIME(self, type_: Any, **kw: Dict[str, Any]) -> str:
171174
return self.visit_TIMESTAMP(type_, **kw)
172175

173-
def visit_TIMESTAMP(self, type_, **kw):
176+
def visit_TIMESTAMP(self, type_: Any, **kw: Dict[str, Any]) -> str:
174177
datatype = "TIMESTAMP"
175178
precision = getattr(type_, "precision", None)
176179
if precision not in range(0, 13) and precision is not None:
@@ -182,7 +185,7 @@ def visit_TIMESTAMP(self, type_, **kw):
182185

183186
return datatype
184187

185-
def visit_TIME(self, type_, **kw):
188+
def visit_TIME(self, type_: Any, **kw: Dict[str, Any]) -> str:
186189
datatype = "TIME"
187190
precision = getattr(type_, "precision", None)
188191
if precision not in range(0, 13) and precision is not None:
@@ -193,13 +196,13 @@ def visit_TIME(self, type_, **kw):
193196
datatype += " WITH TIME ZONE"
194197
return datatype
195198

196-
def visit_JSON(self, type_, **kw):
199+
def visit_JSON(self, type_: Any, **kw: Dict[str, Any]) -> str:
197200
return 'JSON'
198201

199202

200203
class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
201204
reserved_words = RESERVED_WORDS
202205

203-
def format_table(self, table, use_schema=True, name=None):
206+
def format_table(self, table: Table, use_schema: bool = True, name: Optional[str] = None) -> str:
204207
result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name)
205208
return TrinoSQLCompiler.add_catalog(result, table)

trino/sqlalchemy/datatype.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,18 @@
1111
# limitations under the License.
1212
import json
1313
import re
14-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
14+
from typing import Any, Dict, Iterator, List, Optional
15+
from typing import Text as typing_Text
16+
from typing import Tuple, Type, TypeVar, Union
1517

1618
from sqlalchemy import util
19+
from sqlalchemy.engine.interfaces import Dialect
1720
from sqlalchemy.sql import sqltypes
1821
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
1922
from sqlalchemy.types import String
2023

2124
SQLType = Union[TypeEngine, Type[TypeEngine]]
25+
_T = TypeVar('_T')
2226

2327

2428
class DOUBLE(sqltypes.Float):
@@ -38,7 +42,7 @@ def __init__(self, key_type: SQLType, value_type: SQLType):
3842
self.value_type: TypeEngine = value_type
3943

4044
@property
41-
def python_type(self):
45+
def python_type(self) -> type:
4246
return dict
4347

4448

@@ -53,36 +57,36 @@ def __init__(self, attr_types: List[Tuple[Optional[str], SQLType]]):
5357
self.attr_types.append((attr_name, attr_type))
5458

5559
@property
56-
def python_type(self):
60+
def python_type(self) -> type:
5761
return list
5862

5963

6064
class TIME(sqltypes.TIME):
6165
__visit_name__ = "TIME"
6266

63-
def __init__(self, precision=None, timezone=False):
67+
def __init__(self, precision: Optional[int] = None, timezone: bool = False):
6468
super(TIME, self).__init__(timezone=timezone)
6569
self.precision = precision
6670

6771

6872
class TIMESTAMP(sqltypes.TIMESTAMP):
6973
__visit_name__ = "TIMESTAMP"
7074

71-
def __init__(self, precision=None, timezone=False):
75+
def __init__(self, precision: Optional[int] = None, timezone: bool = False):
7276
super(TIMESTAMP, self).__init__(timezone=timezone)
7377
self.precision = precision
7478

7579

7680
class JSON(TypeDecorator):
7781
impl = String
7882

79-
def process_bind_param(self, value, dialect):
83+
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Optional[typing_Text]:
8084
return json.dumps(value)
8185

82-
def process_result_value(self, value, dialect):
86+
def process_result_value(self, value: Union[str, bytes], dialect: Dialect) -> Optional[_T]:
8387
return json.loads(value)
8488

85-
def get_col_spec(self, **kw):
89+
def get_col_spec(self, **kw: Dict[str, Any]) -> str:
8690
return 'JSON'
8791

8892

0 commit comments

Comments
 (0)