|
| 1 | +from typing import Callable |
| 2 | +from typing import Generic |
| 3 | +from typing import Optional |
| 4 | +from typing import Tuple |
| 5 | +from typing import TypeVar |
| 6 | +from typing import Union |
| 7 | +from typing import cast |
| 8 | + |
| 9 | +from sqlalchemy.sql import functions |
| 10 | +from typing_extensions import ParamSpec |
| 11 | + |
| 12 | + |
| 13 | +def _get_docstring(name: str, doc: Union[None, str, Tuple[str, str]], type_: Optional[type]) -> str: |
| 14 | + doc_string_parts = [] |
| 15 | + |
| 16 | + if isinstance(doc, tuple): |
| 17 | + doc_string_parts.append(doc[0]) |
| 18 | + doc_string_parts.append("see http://postgis.net/docs/{0}.html".format(doc[1])) |
| 19 | + elif doc is not None: |
| 20 | + doc_string_parts.append(doc) |
| 21 | + doc_string_parts.append("see http://postgis.net/docs/{0}.html".format(name)) |
| 22 | + |
| 23 | + if type_ is not None: |
| 24 | + return_type_str = "{0}.{1}".format(type_.__module__, type_.__name__) |
| 25 | + doc_string_parts.append("Return type: :class:`{0}`.".format(return_type_str)) |
| 26 | + |
| 27 | + return "\n\n".join(doc_string_parts) |
| 28 | + |
| 29 | + |
| 30 | +def _replace_indent(text: str, indent: str) -> str: |
| 31 | + lines = [] |
| 32 | + for i, line in enumerate(text.splitlines()): |
| 33 | + if i == 0 or not line.strip(): |
| 34 | + lines.append(line) |
| 35 | + else: |
| 36 | + lines.append(f"{indent}{line}") |
| 37 | + return "\n".join(lines) |
| 38 | + |
| 39 | + |
| 40 | +def _generate_stubs() -> str: |
| 41 | + """Generates type stubs for the dynamic functions described in `geoalchemy2/_functions.py`.""" |
| 42 | + from geoalchemy2._functions import _FUNCTIONS |
| 43 | + from geoalchemy2.functions import ST_AsGeoJSON |
| 44 | + |
| 45 | + header = '''\ |
| 46 | +# this file is automatically generated |
| 47 | +from typing import Any |
| 48 | +from typing import List |
| 49 | +
|
| 50 | +from sqlalchemy.sql import functions |
| 51 | +from sqlalchemy.sql.elements import ColumnElement |
| 52 | +
|
| 53 | +import geoalchemy2.types |
| 54 | +from geoalchemy2._functions_helpers import _generic_function |
| 55 | +
|
| 56 | +class GenericFunction(functions.GenericFunction): ... |
| 57 | +
|
| 58 | +class TableRowElement(ColumnElement): |
| 59 | + inherit_cache: bool = ... |
| 60 | + """The cache is disabled for this class.""" |
| 61 | +
|
| 62 | + def __init__(self, selectable: bool) -> None: ... |
| 63 | + @property |
| 64 | + def _from_objects(self) -> List[bool]: ... # type: ignore[override] |
| 65 | +''' |
| 66 | + stub_file_parts = [header] |
| 67 | + |
| 68 | + functions = _FUNCTIONS.copy() |
| 69 | + functions.insert(0, ("ST_AsGeoJSON", str, ST_AsGeoJSON.__doc__)) |
| 70 | + |
| 71 | + for name, type_, doc in functions: |
| 72 | + doc = _replace_indent(_get_docstring(name, doc, type_), " ") |
| 73 | + |
| 74 | + if type_ is None: |
| 75 | + type_str = "None" |
| 76 | + elif type_.__module__ == "builtins": |
| 77 | + type_str = type_.__name__ |
| 78 | + else: |
| 79 | + type_str = f"{type_.__module__}.{type_.__name__}" |
| 80 | + |
| 81 | + signature = f'''\ |
| 82 | +@_generic_function |
| 83 | +def {name}(*args: Any, **kwargs: Any) -> {type_str}: |
| 84 | + """{doc}""" |
| 85 | + ... |
| 86 | +''' |
| 87 | + stub_file_parts.append(signature) |
| 88 | + |
| 89 | + return "\n".join(stub_file_parts) |
| 90 | + |
| 91 | + |
| 92 | +_P = ParamSpec("_P") |
| 93 | +_R = TypeVar("_R", covariant=True) |
| 94 | + |
| 95 | + |
| 96 | +class _GenericFunction(functions.GenericFunction, Generic[_P, _R]): |
| 97 | + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # type: ignore[empty-body] |
| 98 | + ... |
| 99 | + |
| 100 | + |
| 101 | +def _generic_function(func: Callable[_P, _R]) -> _GenericFunction[_P, _R]: |
| 102 | + """Take a regular function and extend it with attributes from sqlalchemy GenericFunction. |
| 103 | +
|
| 104 | + based on https://github.com/python/mypy/issues/2087#issuecomment-1194111648 |
| 105 | + """ |
| 106 | + return cast(_GenericFunction[_P, _R], func) |
0 commit comments