Skip to content

Commit bdaf3b6

Browse files
committed
added type stub generation for dynamic functions
1 parent 86d807a commit bdaf3b6

9 files changed

+3506
-27
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# IDE
2+
.vscode
3+
.idea
4+
15
# Byte-compiled / optimized / DLL files
26
__pycache__/
37
*.py[cod]

_generate_type_stubs.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from pathlib import Path
2+
3+
from geoalchemy2._functions_helpers import _generate_stubs
4+
5+
"""
6+
this script is outside the geoalchemy2 package because the 'geoalchemy2.types'
7+
package interferes with the 'types' module in the standard library
8+
"""
9+
10+
script_dir = Path(__file__).resolve().parent
11+
12+
13+
if __name__ == "__main__":
14+
(script_dir / "geoalchemy2/functions.pyi").write_text(_generate_stubs())

geoalchemy2/_functions.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
# -*- coding: utf-8 -*-
22
# flake8: noqa
3+
from typing import List
4+
from typing import Optional
5+
from typing import Tuple
6+
from typing import Union
7+
38
from geoalchemy2 import types
49

510
# fmt: off
6-
_FUNCTIONS = [
11+
_FUNCTIONS: List[Tuple[str, Optional[type], Union[None, str, Tuple[str, str]]]] = [
712
('AddGeometryColumn', None,
813
'''Adds a geometry column to an existing table.'''),
914
('DropGeometryColumn', None,

geoalchemy2/_functions_helpers.py

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from typing import Callable
2+
from typing import Generic
3+
from typing import Optional
4+
from typing import Tuple
5+
from typing import TypeVar
6+
from typing import Union
7+
from typing import cast
8+
9+
from sqlalchemy.sql import functions
10+
from typing_extensions import ParamSpec
11+
12+
13+
def _get_docstring(name: str, doc: Union[None, str, Tuple[str, str]], type_: Optional[type]) -> str:
14+
doc_string_parts = []
15+
16+
if isinstance(doc, tuple):
17+
doc_string_parts.append(doc[0])
18+
doc_string_parts.append("see http://postgis.net/docs/{0}.html".format(doc[1]))
19+
elif doc is not None:
20+
doc_string_parts.append(doc)
21+
doc_string_parts.append("see http://postgis.net/docs/{0}.html".format(name))
22+
23+
if type_ is not None:
24+
return_type_str = "{0}.{1}".format(type_.__module__, type_.__name__)
25+
doc_string_parts.append("Return type: :class:`{0}`.".format(return_type_str))
26+
27+
return "\n\n".join(doc_string_parts)
28+
29+
30+
def _replace_indent(text: str, indent: str) -> str:
31+
lines = []
32+
for i, line in enumerate(text.splitlines()):
33+
if i == 0 or not line.strip():
34+
lines.append(line)
35+
else:
36+
lines.append(f"{indent}{line}")
37+
return "\n".join(lines)
38+
39+
40+
def _generate_stubs() -> str:
41+
"""Generates type stubs for the dynamic functions described in `geoalchemy2/_functions.py`."""
42+
from geoalchemy2._functions import _FUNCTIONS
43+
from geoalchemy2.functions import ST_AsGeoJSON
44+
45+
header = '''\
46+
# this file is automatically generated
47+
from typing import Any
48+
from typing import List
49+
50+
from sqlalchemy.sql import functions
51+
from sqlalchemy.sql.elements import ColumnElement
52+
53+
import geoalchemy2.types
54+
from geoalchemy2._functions_helpers import _generic_function
55+
56+
class GenericFunction(functions.GenericFunction): ...
57+
58+
class TableRowElement(ColumnElement):
59+
inherit_cache: bool = ...
60+
"""The cache is disabled for this class."""
61+
62+
def __init__(self, selectable: bool) -> None: ...
63+
@property
64+
def _from_objects(self) -> List[bool]: ... # type: ignore[override]
65+
'''
66+
stub_file_parts = [header]
67+
68+
functions = _FUNCTIONS.copy()
69+
functions.insert(0, ("ST_AsGeoJSON", str, ST_AsGeoJSON.__doc__))
70+
71+
for name, type_, doc in functions:
72+
doc = _replace_indent(_get_docstring(name, doc, type_), " ")
73+
74+
if type_ is None:
75+
type_str = "None"
76+
elif type_.__module__ == "builtins":
77+
type_str = type_.__name__
78+
else:
79+
type_str = f"{type_.__module__}.{type_.__name__}"
80+
81+
signature = f'''\
82+
@_generic_function
83+
def {name}(*args: Any, **kwargs: Any) -> {type_str}:
84+
"""{doc}"""
85+
...
86+
'''
87+
stub_file_parts.append(signature)
88+
89+
return "\n".join(stub_file_parts)
90+
91+
92+
_P = ParamSpec("_P")
93+
_R = TypeVar("_R", covariant=True)
94+
95+
96+
class _GenericFunction(functions.GenericFunction, Generic[_P, _R]):
97+
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # type: ignore[empty-body]
98+
...
99+
100+
101+
def _generic_function(func: Callable[_P, _R]) -> _GenericFunction[_P, _R]:
102+
"""Take a regular function and extend it with attributes from sqlalchemy GenericFunction.
103+
104+
based on https://github.com/python/mypy/issues/2087#issuecomment-1194111648
105+
"""
106+
return cast(_GenericFunction[_P, _R], func)

geoalchemy2/functions.py

+18-25
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
6868
"""
6969
import re
70+
from typing import List
7071
from typing import Type
7172

7273
from sqlalchemy import inspect
@@ -77,6 +78,7 @@
7778

7879
from geoalchemy2 import elements
7980
from geoalchemy2._functions import _FUNCTIONS
81+
from geoalchemy2._functions_helpers import _get_docstring
8082

8183
_GeoFunctionBase: Type[functions.GenericFunction]
8284
_GeoFunctionParent: Type[functions.GenericFunction]
@@ -131,11 +133,11 @@ class TableRowElement(ColumnElement):
131133
inherit_cache: bool = False
132134
"""The cache is disabled for this class."""
133135

134-
def __init__(self, selectable) -> None:
136+
def __init__(self, selectable: bool) -> None:
135137
self.selectable = selectable
136138

137139
@property
138-
def _from_objects(self):
140+
def _from_objects(self) -> List[bool]:
139141
return [self.selectable]
140142

141143

@@ -262,33 +264,24 @@ def __init__(self, *args, **kwargs) -> None:
262264
]
263265

264266

265-
# Iterate through _FUNCTIONS and create GenericFunction classes dynamically
266-
for name, type_, doc in _FUNCTIONS:
267-
attributes = {
268-
"name": name,
269-
"inherit_cache": True,
270-
}
271-
docs = []
267+
def _create_dynamic_functions() -> None:
268+
# Iterate through _FUNCTIONS and create GenericFunction classes dynamically
269+
for name, type_, doc in _FUNCTIONS:
270+
attributes = {
271+
"name": name,
272+
"inherit_cache": True,
273+
"__doc__": _get_docstring(name, doc, type_),
274+
}
272275

273-
if isinstance(doc, tuple):
274-
docs.append(doc[0])
275-
docs.append("see http://postgis.net/docs/{0}.html".format(doc[1]))
276-
elif doc is not None:
277-
docs.append(doc)
278-
docs.append("see http://postgis.net/docs/{0}.html".format(name))
276+
if type_ is not None:
277+
attributes["type"] = type_
279278

280-
if type_ is not None:
281-
attributes["type"] = type_
279+
globals()[name] = type(name, (GenericFunction,), attributes)
280+
__all__.append(name)
282281

283-
type_str = "{0}.{1}".format(type_.__module__, type_.__name__)
284-
docs.append("Return type: :class:`{0}`.".format(type_str))
285282

286-
if len(docs) != 0:
287-
attributes["__doc__"] = "\n\n".join(docs)
283+
_create_dynamic_functions()
288284

289-
globals()[name] = type(name, (GenericFunction,), attributes)
290-
__all__.append(name)
291285

292-
293-
def __dir__():
286+
def __dir__() -> list[str]:
294287
return __all__

0 commit comments

Comments
 (0)