Skip to content

Commit 912fc09

Browse files
authored
feat: add registry for extension functions (#68)
adds FunctionRegistry that handles lookup and type inference of extension functions. uses derivation expressions under the hood to generate return types. pyyaml is now part of `extensions` extra.
1 parent 2df0cd0 commit 912fc09

5 files changed

+702
-29
lines changed

pyproject.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ dynamic = ["version"]
1212
write_to = "src/substrait/_version.py"
1313

1414
[project.optional-dependencies]
15-
extensions = ["antlr4-python3-runtime"]
15+
extensions = ["antlr4-python3-runtime", "pyyaml"]
1616
gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
17-
test = ["pytest >= 7.0.0", "antlr4-python3-runtime"]
17+
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"]
1818

1919
[tool.pytest.ini_options]
2020
pythonpath = "src"
@@ -31,7 +31,7 @@ target-version = "py39"
3131
# never autoformat upstream or generated code
3232
exclude = ["third_party/", "src/substrait/gen"]
3333
# do not autofix the following (will still get flagged in lint)
34-
unfixable = [
34+
lint.unfixable = [
3535
"F401", # unused imports
3636
"T201", # print statements
3737
]

src/substrait/derivation_expression.py

+28-11
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def _evaluate(x, values: dict):
3737
elif type(x) == SubstraitTypeParser.FunctionCallContext:
3838
exprs = [_evaluate(e, values) for e in x.expr()]
3939
func = x.Identifier().symbol.text
40-
4140
if func == "min":
4241
return min(*exprs)
4342
elif func == "max":
@@ -48,27 +47,39 @@ def _evaluate(x, values: dict):
4847
scalar_type = x.scalarType()
4948
parametrized_type = x.parameterizedType()
5049
if scalar_type:
50+
nullability = (
51+
Type.NULLABILITY_NULLABLE if x.isnull else Type.NULLABILITY_REQUIRED
52+
)
5153
if isinstance(scalar_type, SubstraitTypeParser.I8Context):
52-
return Type(i8=Type.I8())
54+
return Type(i8=Type.I8(nullability=nullability))
5355
elif isinstance(scalar_type, SubstraitTypeParser.I16Context):
54-
return Type(i16=Type.I16())
56+
return Type(i16=Type.I16(nullability=nullability))
5557
elif isinstance(scalar_type, SubstraitTypeParser.I32Context):
56-
return Type(i32=Type.I32())
58+
return Type(i32=Type.I32(nullability=nullability))
5759
elif isinstance(scalar_type, SubstraitTypeParser.I64Context):
58-
return Type(i64=Type.I64())
60+
return Type(i64=Type.I64(nullability=nullability))
5961
elif isinstance(scalar_type, SubstraitTypeParser.Fp32Context):
60-
return Type(fp32=Type.FP32())
62+
return Type(fp32=Type.FP32(nullability=nullability))
6163
elif isinstance(scalar_type, SubstraitTypeParser.Fp64Context):
62-
return Type(fp64=Type.FP64())
64+
return Type(fp64=Type.FP64(nullability=nullability))
6365
elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext):
64-
return Type(bool=Type.Boolean())
66+
return Type(bool=Type.Boolean(nullability=nullability))
6567
else:
6668
raise Exception(f"Unknown scalar type {type(scalar_type)}")
6769
elif parametrized_type:
6870
if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext):
6971
precision = _evaluate(parametrized_type.precision, values)
7072
scale = _evaluate(parametrized_type.scale, values)
71-
return Type(decimal=Type.Decimal(precision=precision, scale=scale))
73+
nullability = (
74+
Type.NULLABILITY_NULLABLE
75+
if parametrized_type.isnull
76+
else Type.NULLABILITY_REQUIRED
77+
)
78+
return Type(
79+
decimal=Type.Decimal(
80+
precision=precision, scale=scale, nullability=nullability
81+
)
82+
)
7283
raise Exception(f"Unknown parametrized type {type(parametrized_type)}")
7384
else:
7485
raise Exception("either scalar_type or parametrized_type is required")
@@ -91,12 +102,18 @@ def _evaluate(x, values: dict):
91102
return _evaluate(x.finalType, values)
92103
elif type(x) == SubstraitTypeParser.TypeLiteralContext:
93104
return _evaluate(x.type_(), values)
105+
elif type(x) == SubstraitTypeParser.NumericLiteralContext:
106+
return int(str(x.Number()))
94107
else:
95108
raise Exception(f"Unknown token type {type(x)}")
96109

97110

98-
def evaluate(x: str, values: Optional[dict] = None):
111+
def _parse(x: str):
99112
lexer = SubstraitTypeLexer(InputStream(x))
100113
stream = CommonTokenStream(lexer)
101114
parser = SubstraitTypeParser(stream)
102-
return _evaluate(parser.expr(), values)
115+
return parser.expr()
116+
117+
118+
def evaluate(x: str, values: Optional[dict] = None):
119+
return _evaluate(_parse(x), values)

src/substrait/function_registry.py

+287
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType
2+
from substrait.gen.proto.type_pb2 import Type
3+
from importlib.resources import files as importlib_files
4+
import itertools
5+
from collections import defaultdict
6+
from collections.abc import Mapping
7+
from pathlib import Path
8+
from typing import Any, Optional, Union
9+
from .derivation_expression import evaluate, _evaluate, _parse
10+
11+
import yaml
12+
13+
14+
DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions"
15+
16+
17+
# mapping from argument types to shortened signature names: https://substrait.io/extensions/#function-signature-compound-names
18+
_normalized_key_names = {
19+
"i8": "i8",
20+
"i16": "i16",
21+
"i32": "i32",
22+
"i64": "i64",
23+
"fp32": "fp32",
24+
"fp64": "fp64",
25+
"string": "str",
26+
"binary": "vbin",
27+
"boolean": "bool",
28+
"timestamp": "ts",
29+
"timestamp_tz": "tstz",
30+
"date": "date",
31+
"time": "time",
32+
"interval_year": "iyear",
33+
"interval_day": "iday",
34+
"interval_compound": "icompound",
35+
"uuid": "uuid",
36+
"fixedchar": "fchar",
37+
"varchar": "vchar",
38+
"fixedbinary": "fbin",
39+
"decimal": "dec",
40+
"precision_time": "pt",
41+
"precision_timestamp": "pts",
42+
"precision_timestamp_tz": "ptstz",
43+
"struct": "struct",
44+
"list": "list",
45+
"map": "map",
46+
}
47+
48+
49+
def normalize_substrait_type_names(typ: str) -> str:
50+
# Strip type specifiers
51+
typ = typ.split("<")[0]
52+
# First strip nullability marker
53+
typ = typ.strip("?").lower()
54+
55+
if typ.startswith("any"):
56+
return "any"
57+
elif typ.startswith("u!"):
58+
return typ
59+
elif typ in _normalized_key_names:
60+
return _normalized_key_names[typ]
61+
else:
62+
raise Exception(f"Unrecognized substrait type {typ}")
63+
64+
65+
def violates_integer_option(actual: int, option, parameters: dict):
66+
if isinstance(option, SubstraitTypeParser.NumericLiteralContext):
67+
return actual != int(str(option.Number()))
68+
elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext):
69+
parameter_name = str(option.Identifier())
70+
if parameter_name in parameters and parameters[parameter_name] != actual:
71+
return True
72+
else:
73+
parameters[parameter_name] = actual
74+
else:
75+
raise Exception(
76+
f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead"
77+
)
78+
79+
return False
80+
81+
82+
from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser
83+
84+
85+
def types_equal(type1: Type, type2: Type, check_nullability=False):
86+
if check_nullability:
87+
return type1 == type2
88+
else:
89+
x, y = Type(), Type()
90+
x.CopyFrom(type1)
91+
y.CopyFrom(type2)
92+
x.__getattribute__(
93+
x.WhichOneof("kind")
94+
).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED
95+
y.__getattribute__(
96+
y.WhichOneof("kind")
97+
).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED
98+
return x == y
99+
100+
101+
def covers(
102+
covered: Type,
103+
covering: SubstraitTypeParser.TypeLiteralContext,
104+
parameters: dict,
105+
check_nullability=False,
106+
):
107+
if isinstance(covering, SubstraitTypeParser.TypeParamContext):
108+
parameter_name = str(covering.Identifier())
109+
110+
if parameter_name in parameters:
111+
covering = parameters[parameter_name]
112+
113+
return types_equal(covering, covered, check_nullability)
114+
else:
115+
parameters[parameter_name] = covered
116+
return True
117+
118+
covering = covering.type_()
119+
scalar_type = covering.scalarType()
120+
if scalar_type:
121+
covering = _evaluate(covering, {})
122+
return types_equal(covering, covered, check_nullability)
123+
124+
parameterized_type = covering.parameterizedType()
125+
if parameterized_type:
126+
if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext):
127+
if covered.WhichOneof("kind") != "decimal":
128+
return False
129+
130+
nullability = (
131+
Type.NULLABILITY_NULLABLE
132+
if parameterized_type.isnull
133+
else Type.NULLABILITY_REQUIRED
134+
)
135+
136+
if (
137+
check_nullability
138+
and nullability
139+
!= covered.__getattribute__(covered.WhichOneof("kind")).nullability
140+
):
141+
return False
142+
143+
return not (
144+
violates_integer_option(
145+
covered.decimal.scale, parameterized_type.scale, parameters
146+
)
147+
or violates_integer_option(
148+
covered.decimal.precision, parameterized_type.precision, parameters
149+
)
150+
)
151+
else:
152+
raise Exception(f"Unhandled type {type(parameterized_type)}")
153+
154+
any_type = covering.anyType()
155+
if any_type:
156+
return True
157+
158+
159+
class FunctionEntry:
160+
def __init__(
161+
self, uri: str, name: str, impl: Mapping[str, Any], anchor: int
162+
) -> None:
163+
self.name = name
164+
self.normalized_inputs: list = []
165+
self.uri: str = uri
166+
self.anchor = anchor
167+
self.arguments = []
168+
self.rtn = impl["return"]
169+
self.nullability = impl.get("nullability", "MIRROR")
170+
self.variadic = impl.get("variadic", False)
171+
if input_args := impl.get("args", []):
172+
for val in input_args:
173+
if typ := val.get("value"):
174+
self.arguments.append(_parse(typ))
175+
self.normalized_inputs.append(normalize_substrait_type_names(typ))
176+
elif arg_name := val.get("name", None):
177+
self.arguments.append(val.get("options"))
178+
self.normalized_inputs.append("req")
179+
180+
def __repr__(self) -> str:
181+
return f"{self.name}:{'_'.join(self.normalized_inputs)}"
182+
183+
def satisfies_signature(self, signature: tuple) -> Optional[str]:
184+
if self.variadic:
185+
min_args_allowed = self.variadic.get("min", 0)
186+
if len(signature) < min_args_allowed:
187+
return None
188+
inputs = [self.arguments[0]] * len(signature)
189+
else:
190+
inputs = self.arguments
191+
if len(inputs) != len(signature):
192+
return None
193+
194+
zipped_args = list(zip(inputs, signature))
195+
196+
parameters = {}
197+
198+
for x, y in zipped_args:
199+
if type(y) == str:
200+
if y not in x:
201+
return None
202+
else:
203+
if not covers(
204+
y, x, parameters, check_nullability=self.nullability == "DISCRETE"
205+
):
206+
return None
207+
208+
output_type = evaluate(self.rtn, parameters)
209+
210+
if self.nullability == "MIRROR":
211+
sig_contains_nullable = any(
212+
[
213+
p.__getattribute__(p.WhichOneof("kind")).nullability
214+
== Type.NULLABILITY_NULLABLE
215+
for p in signature
216+
if type(p) == Type
217+
]
218+
)
219+
output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = (
220+
Type.NULLABILITY_NULLABLE
221+
if sig_contains_nullable
222+
else Type.NULLABILITY_REQUIRED
223+
)
224+
225+
return output_type
226+
227+
228+
class FunctionRegistry:
229+
def __init__(self, load_default_extensions=True) -> None:
230+
self._function_mapping: dict = defaultdict(dict)
231+
self._id_generator = itertools.count(1)
232+
233+
self._uri_aliases = {}
234+
235+
if load_default_extensions:
236+
for fpath in importlib_files("substrait.extensions").glob( # type: ignore
237+
"functions*.yaml"
238+
):
239+
uri = f"{DEFAULT_URI_PREFIX}/{fpath.name}"
240+
self._uri_aliases[fpath.name] = uri
241+
self.register_extension_yaml(fpath, uri)
242+
243+
def register_extension_yaml(
244+
self,
245+
fname: Union[str, Path],
246+
uri: str,
247+
) -> None:
248+
fname = Path(fname)
249+
with open(fname) as f: # type: ignore
250+
extension_definitions = yaml.safe_load(f)
251+
252+
self.register_extension_dict(extension_definitions, uri)
253+
254+
def register_extension_dict(self, definitions: dict, uri: str) -> None:
255+
for named_functions in definitions.values():
256+
for function in named_functions:
257+
for impl in function.get("impls", []):
258+
func = FunctionEntry(
259+
uri, function["name"], impl, next(self._id_generator)
260+
)
261+
if (
262+
func.uri in self._function_mapping
263+
and function["name"] in self._function_mapping[func.uri]
264+
):
265+
self._function_mapping[func.uri][function["name"]].append(func)
266+
else:
267+
self._function_mapping[func.uri][function["name"]] = [func]
268+
269+
# TODO add an optional return type check
270+
def lookup_function(
271+
self, uri: str, function_name: str, signature: tuple
272+
) -> Optional[tuple[FunctionEntry, Type]]:
273+
uri = self._uri_aliases.get(uri, uri)
274+
275+
if (
276+
uri not in self._function_mapping
277+
or function_name not in self._function_mapping[uri]
278+
):
279+
return None
280+
functions = self._function_mapping[uri][function_name]
281+
for f in functions:
282+
assert isinstance(f, FunctionEntry)
283+
rtn = f.satisfies_signature(signature)
284+
if rtn is not None:
285+
return (f, rtn)
286+
287+
return None

0 commit comments

Comments
 (0)