From 1efb3f062a501697fd41db851acd087e1a44f05f Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Mon, 10 Mar 2025 18:11:24 +0000 Subject: [PATCH 1/7] feat: add registry for extension functions --- pyproject.toml | 6 +- src/substrait/function_registry.py | 316 +++++++++++++++++++++++++++++ tests/test_function_registry.py | 220 ++++++++++++++++++++ 3 files changed, 539 insertions(+), 3 deletions(-) create mode 100644 src/substrait/function_registry.py create mode 100644 tests/test_function_registry.py diff --git a/pyproject.toml b/pyproject.toml index 808aa28..4c4ab62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,9 @@ dynamic = ["version"] write_to = "src/substrait/_version.py" [project.optional-dependencies] -extensions = ["antlr4-python3-runtime"] +extensions = ["antlr4-python3-runtime", "pyyaml"] gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"] -test = ["pytest >= 7.0.0", "antlr4-python3-runtime"] +test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"] [tool.pytest.ini_options] pythonpath = "src" @@ -31,7 +31,7 @@ target-version = "py39" # never autoformat upstream or generated code exclude = ["third_party/", "src/substrait/gen"] # do not autofix the following (will still get flagged in lint) -unfixable = [ +lint.unfixable = [ "F401", # unused imports "T201", # print statements ] diff --git a/src/substrait/function_registry.py b/src/substrait/function_registry.py new file mode 100644 index 0000000..02eae36 --- /dev/null +++ b/src/substrait/function_registry.py @@ -0,0 +1,316 @@ +from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType +from substrait.gen.proto.type_pb2 import Type +from importlib.resources import files as importlib_files +import itertools +from collections import defaultdict +from collections.abc import Mapping +from pathlib import Path +from typing import Any, Optional, Union +from .derivation_expression import evaluate + +import yaml +import re + +_normalized_key_names = { + "binary": "vbin", + "interval_compound": "icompound", + "interval_day": "iday", + "interval_year": "iyear", + "string": "str", + "timestamp": "ts", + "timestamp_tz": "tstz", +} + + +def normalize_substrait_type_names(typ: str) -> str: + # First strip off any punctuation + typ = typ.strip("?").lower() + + # Common prefixes whose information does not matter to an extension function + # signature + for complex_type, abbr in [ + ("fixedchar", "fchar"), + ("varchar", "vchar"), + ("fixedbinary", "fbin"), + ("decimal", "dec"), + ("precision_timestamp", "pts"), + ("precision_timestamp_tz", "ptstz"), + ("struct", "struct"), + ("list", "list"), + ("map", "map"), + ("any", "any"), + ("boolean", "bool"), + ]: + if typ.lower().startswith(complex_type): + typ = abbr + + # Then pass through the dictionary of mappings, defaulting to just the + # existing string + typ = _normalized_key_names.get(typ.lower(), typ.lower()) + return typ + + +id_generator = itertools.count(1) + + +def to_integer_option(txt: str): + if txt.isnumeric(): + return ParameterizedType.IntegerOption(literal=int(txt)) + else: + return ParameterizedType.IntegerOption( + parameter=ParameterizedType.IntegerParameter(name=txt) + ) + + +def to_parameterized_type(dtype: str): + if dtype == "boolean": + return ParameterizedType(bool=Type.Boolean()) + elif dtype == "i8": + return ParameterizedType(i8=Type.I8()) + elif dtype == "i16": + return ParameterizedType(i16=Type.I16()) + elif dtype == "i32": + return ParameterizedType(i32=Type.I32()) + elif dtype == "i64": + return ParameterizedType(i64=Type.I64()) + elif dtype == "fp32": + return ParameterizedType(fp32=Type.FP32()) + elif dtype == "fp64": + return ParameterizedType(fp64=Type.FP64()) + elif dtype == "timestamp": + return ParameterizedType(timestamp=Type.Timestamp()) + elif dtype == "timestamp_tz": + return ParameterizedType(timestamp_tz=Type.TimestampTZ()) + elif dtype == "date": + return ParameterizedType(date=Type.Date()) + elif dtype == "time": + return ParameterizedType(time=Type.Time()) + elif dtype == "interval_year": + return ParameterizedType(interval_year=Type.IntervalYear()) + elif dtype.startswith("decimal") or dtype.startswith("DECIMAL"): + (_, precision, scale, _) = re.split(r"\W+", dtype) + + return ParameterizedType( + decimal=ParameterizedType.ParameterizedDecimal( + scale=to_integer_option(scale), precision=to_integer_option(precision) + ) + ) + elif dtype.startswith("varchar"): + (_, length, _) = re.split(r"\W+", dtype) + + return ParameterizedType( + varchar=ParameterizedType.ParameterizedVarChar( + length=to_integer_option(length) + ) + ) + elif dtype.startswith("precision_timestamp"): + (_, precision, _) = re.split(r"\W+", dtype) + + return ParameterizedType( + precision_timestamp=ParameterizedType.ParameterizedPrecisionTimestamp( + precision=to_integer_option(precision) + ) + ) + elif dtype.startswith("precision_timestamp_tz"): + (_, precision, _) = re.split(r"\W+", dtype) + + return ParameterizedType( + precision_timestamp_tz=ParameterizedType.ParameterizedPrecisionTimestampTZ( + precision=to_integer_option(precision) + ) + ) + elif dtype.startswith("fixedchar"): + (_, length, _) = re.split(r"\W+", dtype) + + return ParameterizedType( + fixed_char=ParameterizedType.ParameterizedFixedChar( + length=to_integer_option(length) + ) + ) + elif dtype == "string": + return ParameterizedType(string=Type.String()) + elif dtype.startswith("list"): + inner_dtype = dtype[5:-1] + return ParameterizedType( + list=ParameterizedType.ParameterizedList( + type=to_parameterized_type(inner_dtype) + ) + ) + elif dtype.startswith("interval_day"): + (_, precision, _) = re.split(r"\W+", dtype) + + return ParameterizedType( + interval_day=ParameterizedType.ParameterizedIntervalDay( + precision=to_integer_option(precision) + ) + ) + elif dtype.startswith("any"): + return ParameterizedType( + type_parameter=ParameterizedType.TypeParameter(name=dtype) + ) + elif dtype.startswith("u!") or dtype == "geometry": + return ParameterizedType( + user_defined=ParameterizedType.ParameterizedUserDefined() + ) + else: + raise Exception(f"Unkownn type - {dtype}") + + +def violates_integer_option( + actual: int, option: ParameterizedType.IntegerOption, parameters: dict +): + integer_type = option.WhichOneof("integer_type") + + if integer_type == "literal" and actual != option.literal: + return True + else: + parameter_name = option.parameter.name + if parameter_name in parameters and parameters[parameter_name] != actual: + return True + else: + parameters[parameter_name] = actual + + return False + + +def covers(dtype: Type, parameterized_type: ParameterizedType, parameters: dict): + expected_kind = parameterized_type.WhichOneof("kind") + + if expected_kind == "type_parameter": + parameter_name = parameterized_type.type_parameter.name + if parameter_name == "any": + return True + else: + if parameter_name in parameters and parameters[ + parameter_name + ].SerializeToString(deterministic=True) != dtype.SerializeToString( + deterministic=True + ): + return False + else: + parameters[parameter_name] = dtype + return True + + kind = dtype.WhichOneof("kind") + + if kind != expected_kind: + return False + + if kind == "decimal": + if violates_integer_option( + dtype.decimal.scale, parameterized_type.decimal.scale, parameters + ) or violates_integer_option( + dtype.decimal.precision, parameterized_type.decimal.precision, parameters + ): + return False + + # TODO handle all types + + return True + + +class FunctionEntry: + def __init__(self, uri: str, name: str, impl: Mapping[str, Any]) -> None: + self.name = name + self.normalized_inputs: list = [] + self.uri: str = uri + self.anchor = next(id_generator) + self.arguments = [] + self.rtn = impl["return"] + self.nullability = impl.get("nullability", False) + self.variadic = impl.get("variadic", False) + if input_args := impl.get("args", []): + for val in input_args: + if typ := val.get("value"): + self.arguments.append(to_parameterized_type(typ.strip("?"))) + self.normalized_inputs.append(normalize_substrait_type_names(typ)) + elif arg_name := val.get("name", None): + self.arguments.append(val.get("options")) + self.normalized_inputs.append("req") + + def __repr__(self) -> str: + return f"{self.name}:{'_'.join(self.normalized_inputs)}" + + def satisfies_signature(self, signature: tuple) -> Optional[str]: + if self.variadic: + min_args_allowed = self.variadic.get("min", 0) + if len(signature) < min_args_allowed: + return None + inputs = [self.arguments[0]] * len(signature) + else: + inputs = self.arguments + if len(inputs) != len(signature): + return None + + zipped_args = list(zip(inputs, signature)) + + parameters = {} + + for x, y in zipped_args: + if type(y) == str: + if y not in x: + return None + else: + if not covers(y, x, parameters): + return None + + return evaluate(self.rtn, parameters) + + +class FunctionRegistry: + def __init__(self) -> None: + self._function_mapping: dict = defaultdict(dict) + self.id_generator = itertools.count(1) + + self.uri_aliases = {} + + for fpath in importlib_files("substrait.extensions").glob( # type: ignore + "functions*.yaml" + ): + uri = f"https://github.com/substrait-io/substrait/blob/main/extensions/{fpath.name}" + self.uri_aliases[fpath.name] = uri + self.register_extension_yaml(fpath, uri) + + def register_extension_yaml( + self, + fname: Union[str, Path], + uri: str, + ) -> None: + fname = Path(fname) + with open(fname) as f: # type: ignore + extension_definitions = yaml.safe_load(f) + + self.register_extension_dict(extension_definitions, uri) + + def register_extension_dict(self, definitions: dict, uri: str) -> None: + for named_functions in definitions.values(): + for function in named_functions: + for impl in function.get("impls", []): + func = FunctionEntry(uri, function["name"], impl) + if ( + func.uri in self._function_mapping + and function["name"] in self._function_mapping[func.uri] + ): + self._function_mapping[func.uri][function["name"]].append(func) + else: + self._function_mapping[func.uri][function["name"]] = [func] + + # TODO add an optional return type check + def lookup_function( + self, uri: str, function_name: str, signature: tuple + ) -> Optional[tuple[FunctionEntry, Type]]: + uri = self.uri_aliases.get(uri, uri) + + if ( + uri not in self._function_mapping + or function_name not in self._function_mapping[uri] + ): + return None + functions = self._function_mapping[uri][function_name] + for f in functions: + assert isinstance(f, FunctionEntry) + rtn = f.satisfies_signature(signature) + if rtn is not None: + return (f, rtn) + + return None diff --git a/tests/test_function_registry.py b/tests/test_function_registry.py new file mode 100644 index 0000000..20e4a26 --- /dev/null +++ b/tests/test_function_registry.py @@ -0,0 +1,220 @@ +import yaml + +from substrait.gen.proto.type_pb2 import Type +from substrait.function_registry import FunctionRegistry + +content = """%YAML 1.2 +--- +scalar_functions: + - name: "test_fn" + description: "" + impls: + - args: + - value: i8 + variadic: + min: 2 + return: i8 + - name: "test_fn_variadic_any" + description: "" + impls: + - args: + - value: any1 + variadic: + min: 2 + return: any1 + - name: "add" + description: "Add two values." + impls: + - args: + - name: x + value: i8 + - name: y + value: i8 + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: i8 + - args: + - name: x + value: i8 + - name: y + value: i8 + - name: z + value: any + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: i16 + - args: + - name: x + value: any1 + - name: y + value: any1 + - name: z + value: any2 + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: any2 + - name: "test_decimal" + impls: + - args: + - name: x + value: decimal + - name: y + value: decimal + return: decimal + - name: "test_enum" + impls: + - args: + - name: op + options: [ INTACT, FLIP ] + - name: x + value: i8 + return: i8 + +""" + + +registry = FunctionRegistry() + +registry.register_extension_dict(yaml.safe_load(content), uri="test") + + +def i8(): + return Type(i8=Type.I8()) + + +def i16(): + return Type(i16=Type.I16()) + + +def bool(): + return Type(bool=Type.Boolean()) + + +def decimal(precision, scale): + return Type(decimal=Type.Decimal(scale=scale, precision=precision)) + + +def test_non_existing_uri(): + assert ( + registry.lookup_function( + uri="non_existent", function_name="add", signature=[i8(), i8()] + ) + is None + ) + + +def test_non_existing_function(): + assert ( + registry.lookup_function( + uri="test", function_name="sub", signature=[i8(), i8()] + ) + is None + ) + + +def test_non_existing_function_signature(): + assert ( + registry.lookup_function(uri="test", function_name="add", signature=[i8()]) + is None + ) + + +def test_exact_match(): + assert registry.lookup_function( + uri="test", function_name="add", signature=[i8(), i8()] + )[1] == Type(i8=Type.I8()) + + +def test_wildcard_match(): + assert registry.lookup_function( + uri="test", function_name="add", signature=[i8(), i8(), bool()] + )[1] == Type(i16=Type.I16()) + + +def test_wildcard_match_fails_with_constraits(): + assert ( + registry.lookup_function( + uri="test", function_name="add", signature=[i8(), i16(), i16()] + ) + is None + ) + + +def test_wildcard_match_with_constraits(): + assert ( + registry.lookup_function( + uri="test", function_name="add", signature=[i16(), i16(), i8()] + )[1] + == i8() + ) + + +def test_variadic(): + assert ( + registry.lookup_function( + uri="test", function_name="test_fn", signature=[i8(), i8(), i8()] + )[1] + == i8() + ) + + +def test_variadic_any(): + assert ( + registry.lookup_function( + uri="test", + function_name="test_fn_variadic_any", + signature=[i16(), i16(), i16()], + )[1] + == i16() + ) + + +def test_variadic_fails_min_constraint(): + assert ( + registry.lookup_function(uri="test", function_name="test_fn", signature=[i8()]) + is None + ) + + +def test_decimal_happy_path(): + assert registry.lookup_function( + uri="test", + function_name="test_decimal", + signature=[decimal(10, 8), decimal(8, 6)], + )[1] == decimal(11, 7) + + +def test_decimal_violates_constraint(): + assert ( + registry.lookup_function( + uri="test", + function_name="test_decimal", + signature=[decimal(10, 8), decimal(12, 10)], + ) + is None + ) + + +def test_enum_with_valid_option(): + assert ( + registry.lookup_function( + uri="test", + function_name="test_enum", + signature=["FLIP", i8()], + )[1] + == i8() + ) + + +def test_enum_with_nonexistent_option(): + assert ( + registry.lookup_function( + uri="test", + function_name="test_enum", + signature=["NONEXISTENT", i8()], + ) + is None + ) From 050ca8724622df2f69cf4d0f912bbb31c25783b1 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Tue, 11 Mar 2025 12:42:02 +0000 Subject: [PATCH 2/7] fix: remove global id_generator in function registry --- src/substrait/function_registry.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/substrait/function_registry.py b/src/substrait/function_registry.py index 02eae36..3a25420 100644 --- a/src/substrait/function_registry.py +++ b/src/substrait/function_registry.py @@ -50,9 +50,6 @@ def normalize_substrait_type_names(typ: str) -> str: return typ -id_generator = itertools.count(1) - - def to_integer_option(txt: str): if txt.isnumeric(): return ParameterizedType.IntegerOption(literal=int(txt)) @@ -153,7 +150,7 @@ def to_parameterized_type(dtype: str): user_defined=ParameterizedType.ParameterizedUserDefined() ) else: - raise Exception(f"Unkownn type - {dtype}") + raise Exception(f"Unknown type - {dtype}") def violates_integer_option( @@ -210,11 +207,11 @@ def covers(dtype: Type, parameterized_type: ParameterizedType, parameters: dict) class FunctionEntry: - def __init__(self, uri: str, name: str, impl: Mapping[str, Any]) -> None: + def __init__(self, uri: str, name: str, impl: Mapping[str, Any], anchor: int) -> None: self.name = name self.normalized_inputs: list = [] self.uri: str = uri - self.anchor = next(id_generator) + self.anchor = anchor self.arguments = [] self.rtn = impl["return"] self.nullability = impl.get("nullability", False) @@ -286,7 +283,7 @@ def register_extension_dict(self, definitions: dict, uri: str) -> None: for named_functions in definitions.values(): for function in named_functions: for impl in function.get("impls", []): - func = FunctionEntry(uri, function["name"], impl) + func = FunctionEntry(uri, function["name"], impl, next(self.id_generator)) if ( func.uri in self._function_mapping and function["name"] in self._function_mapping[func.uri] From e47abad71dd4d9c571f13e1e1c6a64bd07820228 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Tue, 11 Mar 2025 13:26:33 +0000 Subject: [PATCH 3/7] fix: improve signature name handling --- src/substrait/function_registry.py | 117 +++++++++++++++++------------ 1 file changed, 68 insertions(+), 49 deletions(-) diff --git a/src/substrait/function_registry.py b/src/substrait/function_registry.py index 3a25420..2dfd84e 100644 --- a/src/substrait/function_registry.py +++ b/src/substrait/function_registry.py @@ -11,43 +11,50 @@ import yaml import re +# mapping from argument types to shortened signature names: https://substrait.io/extensions/#function-signature-compound-names _normalized_key_names = { - "binary": "vbin", - "interval_compound": "icompound", - "interval_day": "iday", - "interval_year": "iyear", + "i8": "i8", + "i16": "i16", + "i32": "i32", + "i64": "i64", + "fp32": "fp32", + "fp64": "fp64", "string": "str", + "binary": "vbin", + "boolean": "bool", "timestamp": "ts", "timestamp_tz": "tstz", + "date": "date", + "time": "time", + "interval_year": "iyear", + "interval_day": "iday", + "interval_compound": "icompound", + "uuid": "uuid", + "fixedchar": "fchar", + "varchar": "vchar", + "fixedbinary": "fbin", + "decimal": "dec", + "precision_time": "pt", + "precision_timestamp": "pts", + "precision_timestamp_tz": "ptstz", + "struct": "struct", + "list": "list", + "map": "map", } def normalize_substrait_type_names(typ: str) -> str: - # First strip off any punctuation + # First strip nullability marker typ = typ.strip("?").lower() + # Strip type specifiers + typ = typ.split('<')[0] - # Common prefixes whose information does not matter to an extension function - # signature - for complex_type, abbr in [ - ("fixedchar", "fchar"), - ("varchar", "vchar"), - ("fixedbinary", "fbin"), - ("decimal", "dec"), - ("precision_timestamp", "pts"), - ("precision_timestamp_tz", "ptstz"), - ("struct", "struct"), - ("list", "list"), - ("map", "map"), - ("any", "any"), - ("boolean", "bool"), - ]: - if typ.lower().startswith(complex_type): - typ = abbr - - # Then pass through the dictionary of mappings, defaulting to just the - # existing string - typ = _normalized_key_names.get(typ.lower(), typ.lower()) - return typ + if typ.startswith("any"): + return "any" + elif typ.startswith("u!"): + return typ + else: + return _normalized_key_names[typ] def to_integer_option(txt: str): @@ -60,36 +67,42 @@ def to_integer_option(txt: str): def to_parameterized_type(dtype: str): + if dtype.endswith('?'): + dtype = dtype[:-1] + nullability = Type.NULLABILITY_NULLABLE + else: + nullability = Type.NULLABILITY_REQUIRED + if dtype == "boolean": - return ParameterizedType(bool=Type.Boolean()) + return ParameterizedType(bool=Type.Boolean(nullability=nullability)) elif dtype == "i8": - return ParameterizedType(i8=Type.I8()) + return ParameterizedType(i8=Type.I8(nullability=nullability)) elif dtype == "i16": - return ParameterizedType(i16=Type.I16()) + return ParameterizedType(i16=Type.I16(nullability=nullability)) elif dtype == "i32": - return ParameterizedType(i32=Type.I32()) + return ParameterizedType(i32=Type.I32(nullability=nullability)) elif dtype == "i64": - return ParameterizedType(i64=Type.I64()) + return ParameterizedType(i64=Type.I64(nullability=nullability)) elif dtype == "fp32": - return ParameterizedType(fp32=Type.FP32()) + return ParameterizedType(fp32=Type.FP32(nullability=nullability)) elif dtype == "fp64": - return ParameterizedType(fp64=Type.FP64()) + return ParameterizedType(fp64=Type.FP64(nullability=nullability)) elif dtype == "timestamp": - return ParameterizedType(timestamp=Type.Timestamp()) + return ParameterizedType(timestamp=Type.Timestamp(nullability=nullability)) elif dtype == "timestamp_tz": - return ParameterizedType(timestamp_tz=Type.TimestampTZ()) + return ParameterizedType(timestamp_tz=Type.TimestampTZ(nullability=nullability)) elif dtype == "date": - return ParameterizedType(date=Type.Date()) + return ParameterizedType(date=Type.Date(nullability=nullability)) elif dtype == "time": - return ParameterizedType(time=Type.Time()) + return ParameterizedType(time=Type.Time(nullability=nullability)) elif dtype == "interval_year": - return ParameterizedType(interval_year=Type.IntervalYear()) + return ParameterizedType(interval_year=Type.IntervalYear(nullability=nullability)) elif dtype.startswith("decimal") or dtype.startswith("DECIMAL"): (_, precision, scale, _) = re.split(r"\W+", dtype) return ParameterizedType( decimal=ParameterizedType.ParameterizedDecimal( - scale=to_integer_option(scale), precision=to_integer_option(precision) + scale=to_integer_option(scale), precision=to_integer_option(precision), nullability=nullability ) ) elif dtype.startswith("varchar"): @@ -97,7 +110,8 @@ def to_parameterized_type(dtype: str): return ParameterizedType( varchar=ParameterizedType.ParameterizedVarChar( - length=to_integer_option(length) + length=to_integer_option(length), + nullability=nullability ) ) elif dtype.startswith("precision_timestamp"): @@ -105,7 +119,8 @@ def to_parameterized_type(dtype: str): return ParameterizedType( precision_timestamp=ParameterizedType.ParameterizedPrecisionTimestamp( - precision=to_integer_option(precision) + precision=to_integer_option(precision), + nullability=nullability ) ) elif dtype.startswith("precision_timestamp_tz"): @@ -113,7 +128,8 @@ def to_parameterized_type(dtype: str): return ParameterizedType( precision_timestamp_tz=ParameterizedType.ParameterizedPrecisionTimestampTZ( - precision=to_integer_option(precision) + precision=to_integer_option(precision), + nullability=nullability ) ) elif dtype.startswith("fixedchar"): @@ -121,16 +137,18 @@ def to_parameterized_type(dtype: str): return ParameterizedType( fixed_char=ParameterizedType.ParameterizedFixedChar( - length=to_integer_option(length) + length=to_integer_option(length), + nullability=nullability ) ) elif dtype == "string": - return ParameterizedType(string=Type.String()) + return ParameterizedType(string=Type.String(nullability=nullability)) elif dtype.startswith("list"): inner_dtype = dtype[5:-1] return ParameterizedType( list=ParameterizedType.ParameterizedList( - type=to_parameterized_type(inner_dtype) + type=to_parameterized_type(inner_dtype), + nullability=nullability ) ) elif dtype.startswith("interval_day"): @@ -138,7 +156,8 @@ def to_parameterized_type(dtype: str): return ParameterizedType( interval_day=ParameterizedType.ParameterizedIntervalDay( - precision=to_integer_option(precision) + precision=to_integer_option(precision), + nullability=nullability ) ) elif dtype.startswith("any"): @@ -147,7 +166,7 @@ def to_parameterized_type(dtype: str): ) elif dtype.startswith("u!") or dtype == "geometry": return ParameterizedType( - user_defined=ParameterizedType.ParameterizedUserDefined() + user_defined=ParameterizedType.ParameterizedUserDefined(nullability=nullability) ) else: raise Exception(f"Unknown type - {dtype}") @@ -219,7 +238,7 @@ def __init__(self, uri: str, name: str, impl: Mapping[str, Any], anchor: int) -> if input_args := impl.get("args", []): for val in input_args: if typ := val.get("value"): - self.arguments.append(to_parameterized_type(typ.strip("?"))) + self.arguments.append(to_parameterized_type(typ)) self.normalized_inputs.append(normalize_substrait_type_names(typ)) elif arg_name := val.get("name", None): self.arguments.append(val.get("options")) From 58bb259e176bc84d6d57bc1c5e0bac8b21ef4992 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Tue, 11 Mar 2025 13:44:17 +0000 Subject: [PATCH 4/7] fix: make default uri prefix a constant --- src/substrait/function_registry.py | 48 +++++++++++++++++------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/src/substrait/function_registry.py b/src/substrait/function_registry.py index 2dfd84e..fca19f6 100644 --- a/src/substrait/function_registry.py +++ b/src/substrait/function_registry.py @@ -11,6 +11,10 @@ import yaml import re + +DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" + + # mapping from argument types to shortened signature names: https://substrait.io/extensions/#function-signature-compound-names _normalized_key_names = { "i8": "i8", @@ -47,7 +51,7 @@ def normalize_substrait_type_names(typ: str) -> str: # First strip nullability marker typ = typ.strip("?").lower() # Strip type specifiers - typ = typ.split('<')[0] + typ = typ.split("<")[0] if typ.startswith("any"): return "any" @@ -67,7 +71,7 @@ def to_integer_option(txt: str): def to_parameterized_type(dtype: str): - if dtype.endswith('?'): + if dtype.endswith("?"): dtype = dtype[:-1] nullability = Type.NULLABILITY_NULLABLE else: @@ -96,13 +100,17 @@ def to_parameterized_type(dtype: str): elif dtype == "time": return ParameterizedType(time=Type.Time(nullability=nullability)) elif dtype == "interval_year": - return ParameterizedType(interval_year=Type.IntervalYear(nullability=nullability)) + return ParameterizedType( + interval_year=Type.IntervalYear(nullability=nullability) + ) elif dtype.startswith("decimal") or dtype.startswith("DECIMAL"): (_, precision, scale, _) = re.split(r"\W+", dtype) return ParameterizedType( decimal=ParameterizedType.ParameterizedDecimal( - scale=to_integer_option(scale), precision=to_integer_option(precision), nullability=nullability + scale=to_integer_option(scale), + precision=to_integer_option(precision), + nullability=nullability, ) ) elif dtype.startswith("varchar"): @@ -110,8 +118,7 @@ def to_parameterized_type(dtype: str): return ParameterizedType( varchar=ParameterizedType.ParameterizedVarChar( - length=to_integer_option(length), - nullability=nullability + length=to_integer_option(length), nullability=nullability ) ) elif dtype.startswith("precision_timestamp"): @@ -119,8 +126,7 @@ def to_parameterized_type(dtype: str): return ParameterizedType( precision_timestamp=ParameterizedType.ParameterizedPrecisionTimestamp( - precision=to_integer_option(precision), - nullability=nullability + precision=to_integer_option(precision), nullability=nullability ) ) elif dtype.startswith("precision_timestamp_tz"): @@ -128,8 +134,7 @@ def to_parameterized_type(dtype: str): return ParameterizedType( precision_timestamp_tz=ParameterizedType.ParameterizedPrecisionTimestampTZ( - precision=to_integer_option(precision), - nullability=nullability + precision=to_integer_option(precision), nullability=nullability ) ) elif dtype.startswith("fixedchar"): @@ -137,8 +142,7 @@ def to_parameterized_type(dtype: str): return ParameterizedType( fixed_char=ParameterizedType.ParameterizedFixedChar( - length=to_integer_option(length), - nullability=nullability + length=to_integer_option(length), nullability=nullability ) ) elif dtype == "string": @@ -147,8 +151,7 @@ def to_parameterized_type(dtype: str): inner_dtype = dtype[5:-1] return ParameterizedType( list=ParameterizedType.ParameterizedList( - type=to_parameterized_type(inner_dtype), - nullability=nullability + type=to_parameterized_type(inner_dtype), nullability=nullability ) ) elif dtype.startswith("interval_day"): @@ -156,8 +159,7 @@ def to_parameterized_type(dtype: str): return ParameterizedType( interval_day=ParameterizedType.ParameterizedIntervalDay( - precision=to_integer_option(precision), - nullability=nullability + precision=to_integer_option(precision), nullability=nullability ) ) elif dtype.startswith("any"): @@ -166,7 +168,9 @@ def to_parameterized_type(dtype: str): ) elif dtype.startswith("u!") or dtype == "geometry": return ParameterizedType( - user_defined=ParameterizedType.ParameterizedUserDefined(nullability=nullability) + user_defined=ParameterizedType.ParameterizedUserDefined( + nullability=nullability + ) ) else: raise Exception(f"Unknown type - {dtype}") @@ -226,7 +230,9 @@ def covers(dtype: Type, parameterized_type: ParameterizedType, parameters: dict) class FunctionEntry: - def __init__(self, uri: str, name: str, impl: Mapping[str, Any], anchor: int) -> None: + def __init__( + self, uri: str, name: str, impl: Mapping[str, Any], anchor: int + ) -> None: self.name = name self.normalized_inputs: list = [] self.uri: str = uri @@ -283,7 +289,7 @@ def __init__(self) -> None: for fpath in importlib_files("substrait.extensions").glob( # type: ignore "functions*.yaml" ): - uri = f"https://github.com/substrait-io/substrait/blob/main/extensions/{fpath.name}" + uri = f"{DEFAULT_URI_PREFIX}/{fpath.name}" self.uri_aliases[fpath.name] = uri self.register_extension_yaml(fpath, uri) @@ -302,7 +308,9 @@ def register_extension_dict(self, definitions: dict, uri: str) -> None: for named_functions in definitions.values(): for function in named_functions: for impl in function.get("impls", []): - func = FunctionEntry(uri, function["name"], impl, next(self.id_generator)) + func = FunctionEntry( + uri, function["name"], impl, next(self.id_generator) + ) if ( func.uri in self._function_mapping and function["name"] in self._function_mapping[func.uri] From 9770bb467913ec850da27f0ca77ed1f3b81322c4 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Fri, 14 Mar 2025 08:06:06 +0000 Subject: [PATCH 5/7] feat: add scalar function nullability handling --- src/substrait/derivation_expression.py | 28 +++++-- src/substrait/function_registry.py | 55 ++++++++++-- tests/test_derivation_expression.py | 64 ++++++++++---- tests/test_function_registry.py | 111 ++++++++++++++++++++++--- 4 files changed, 215 insertions(+), 43 deletions(-) diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index 276d518..8b813a7 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -48,27 +48,39 @@ def _evaluate(x, values: dict): scalar_type = x.scalarType() parametrized_type = x.parameterizedType() if scalar_type: + nullability = ( + Type.NULLABILITY_NULLABLE if x.isnull else Type.NULLABILITY_REQUIRED + ) if isinstance(scalar_type, SubstraitTypeParser.I8Context): - return Type(i8=Type.I8()) + return Type(i8=Type.I8(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.I16Context): - return Type(i16=Type.I16()) + return Type(i16=Type.I16(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.I32Context): - return Type(i32=Type.I32()) + return Type(i32=Type.I32(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.I64Context): - return Type(i64=Type.I64()) + return Type(i64=Type.I64(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.Fp32Context): - return Type(fp32=Type.FP32()) + return Type(fp32=Type.FP32(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.Fp64Context): - return Type(fp64=Type.FP64()) + return Type(fp64=Type.FP64(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext): - return Type(bool=Type.Boolean()) + return Type(bool=Type.Boolean(nullability=nullability)) else: raise Exception(f"Unknown scalar type {type(scalar_type)}") elif parametrized_type: if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext): precision = _evaluate(parametrized_type.precision, values) scale = _evaluate(parametrized_type.scale, values) - return Type(decimal=Type.Decimal(precision=precision, scale=scale)) + nullability = ( + Type.NULLABILITY_NULLABLE + if parametrized_type.isnull + else Type.NULLABILITY_REQUIRED + ) + return Type( + decimal=Type.Decimal( + precision=precision, scale=scale, nullability=nullability + ) + ) raise Exception(f"Unknown parametrized type {type(parametrized_type)}") else: raise Exception("either scalar_type or parametrized_type is required") diff --git a/src/substrait/function_registry.py b/src/substrait/function_registry.py index fca19f6..97d0ff2 100644 --- a/src/substrait/function_registry.py +++ b/src/substrait/function_registry.py @@ -48,10 +48,10 @@ def normalize_substrait_type_names(typ: str) -> str: - # First strip nullability marker - typ = typ.strip("?").lower() # Strip type specifiers typ = typ.split("<")[0] + # First strip nullability marker + typ = typ.strip("?").lower() if typ.startswith("any"): return "any" @@ -70,9 +70,10 @@ def to_integer_option(txt: str): ) +# TODO try using antlr grammar here as well def to_parameterized_type(dtype: str): - if dtype.endswith("?"): - dtype = dtype[:-1] + if "?" in dtype: + dtype = dtype.replace("?", "") nullability = Type.NULLABILITY_NULLABLE else: nullability = Type.NULLABILITY_REQUIRED @@ -193,11 +194,17 @@ def violates_integer_option( return False -def covers(dtype: Type, parameterized_type: ParameterizedType, parameters: dict): +def covers( + dtype: Type, + parameterized_type: ParameterizedType, + parameters: dict, + check_nullability=False, +): expected_kind = parameterized_type.WhichOneof("kind") if expected_kind == "type_parameter": parameter_name = parameterized_type.type_parameter.name + # TODO figure out how to do nullability checks with "any" types if parameter_name == "any": return True else: @@ -211,11 +218,22 @@ def covers(dtype: Type, parameterized_type: ParameterizedType, parameters: dict) parameters[parameter_name] = dtype return True + expected_nullability = parameterized_type.__getattribute__( + parameterized_type.WhichOneof("kind") + ).nullability + kind = dtype.WhichOneof("kind") if kind != expected_kind: return False + if ( + check_nullability + and dtype.__getattribute__(dtype.WhichOneof("kind")).nullability + != expected_nullability + ): + return False + if kind == "decimal": if violates_integer_option( dtype.decimal.scale, parameterized_type.decimal.scale, parameters @@ -225,7 +243,6 @@ def covers(dtype: Type, parameterized_type: ParameterizedType, parameters: dict) return False # TODO handle all types - return True @@ -239,7 +256,7 @@ def __init__( self.anchor = anchor self.arguments = [] self.rtn = impl["return"] - self.nullability = impl.get("nullability", False) + self.nullability = impl.get("nullability", "MIRROR") self.variadic = impl.get("variadic", False) if input_args := impl.get("args", []): for val in input_args: @@ -273,10 +290,30 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: if y not in x: return None else: - if not covers(y, x, parameters): + if not covers( + y, x, parameters, check_nullability=self.nullability == "DISCRETE" + ): return None - return evaluate(self.rtn, parameters) + output_type = evaluate(self.rtn, parameters) + print(output_type) + + if self.nullability == "MIRROR": + sig_contains_nullable = any( + [ + p.__getattribute__(p.WhichOneof("kind")).nullability + == Type.NULLABILITY_NULLABLE + for p in signature + if type(p) == Type + ] + ) + output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( + Type.NULLABILITY_NULLABLE + if sig_contains_nullable + else Type.NULLABILITY_REQUIRED + ) + + return output_type class FunctionRegistry: diff --git a/tests/test_derivation_expression.py b/tests/test_derivation_expression.py index 5df2e2d..4b11b3d 100644 --- a/tests/test_derivation_expression.py +++ b/tests/test_derivation_expression.py @@ -24,29 +24,59 @@ def test_ternary(): def test_multiline(): - assert ( - evaluate( - """temp = min(var, 7) + max(var, 7) + assert evaluate( + """temp = min(var, 7) + max(var, 7) decimal""", - {"var": 5}, + {"var": 5}, + ) == Type( + decimal=Type.Decimal( + precision=13, scale=11, nullability=Type.NULLABILITY_REQUIRED ) - == Type(decimal=Type.Decimal(precision=13, scale=11)) ) def test_simple_data_types(): - assert evaluate("i8") == Type(i8=Type.I8()) - assert evaluate("i16") == Type(i16=Type.I16()) - assert evaluate("i32") == Type(i32=Type.I32()) - assert evaluate("i64") == Type(i64=Type.I64()) - assert evaluate("fp32") == Type(fp32=Type.FP32()) - assert evaluate("fp64") == Type(fp64=Type.FP64()) - assert evaluate("boolean") == Type(bool=Type.Boolean()) + assert evaluate("i8") == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) + assert evaluate("i16") == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) + assert evaluate("i32") == Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)) + assert evaluate("i64") == Type(i64=Type.I64(nullability=Type.NULLABILITY_REQUIRED)) + assert evaluate("fp32") == Type( + fp32=Type.FP32(nullability=Type.NULLABILITY_REQUIRED) + ) + assert evaluate("fp64") == Type( + fp64=Type.FP64(nullability=Type.NULLABILITY_REQUIRED) + ) + assert evaluate("boolean") == Type( + bool=Type.Boolean(nullability=Type.NULLABILITY_REQUIRED) + ) + assert evaluate("i8?") == Type(i8=Type.I8(nullability=Type.NULLABILITY_NULLABLE)) + assert evaluate("i16?") == Type(i16=Type.I16(nullability=Type.NULLABILITY_NULLABLE)) + assert evaluate("i32?") == Type(i32=Type.I32(nullability=Type.NULLABILITY_NULLABLE)) + assert evaluate("i64?") == Type(i64=Type.I64(nullability=Type.NULLABILITY_NULLABLE)) + assert evaluate("fp32?") == Type( + fp32=Type.FP32(nullability=Type.NULLABILITY_NULLABLE) + ) + assert evaluate("fp64?") == Type( + fp64=Type.FP64(nullability=Type.NULLABILITY_NULLABLE) + ) + assert evaluate("boolean?") == Type( + bool=Type.Boolean(nullability=Type.NULLABILITY_NULLABLE) + ) def test_data_type(): assert evaluate("decimal

", {"S": 10, "P": 20}) == Type( - decimal=Type.Decimal(precision=21, scale=11) + decimal=Type.Decimal( + precision=21, scale=11, nullability=Type.NULLABILITY_REQUIRED + ) + ) + + +def test_data_type_nullable(): + assert evaluate("decimal?

", {"S": 10, "P": 20}) == Type( + decimal=Type.Decimal( + precision=21, scale=11, nullability=Type.NULLABILITY_NULLABLE + ) ) @@ -59,7 +89,11 @@ def func(P1, S1, P2, S2): prec = min(init_prec, 38) scale_after_borrow = max(init_scale - delta, min_scale) scale = scale_after_borrow if init_prec > 38 else init_scale - return Type(decimal=Type.Decimal(precision=prec, scale=scale)) + return Type( + decimal=Type.Decimal( + precision=prec, scale=scale, nullability=Type.NULLABILITY_REQUIRED + ) + ) args = {"P1": 10, "S1": 8, "P2": 14, "S2": 2} @@ -78,4 +112,4 @@ def func(P1, S1, P2, S2): args, ) == func_eval - ) \ No newline at end of file + ) diff --git a/tests/test_function_registry.py b/tests/test_function_registry.py index 20e4a26..85b0526 100644 --- a/tests/test_function_registry.py +++ b/tests/test_function_registry.py @@ -72,7 +72,35 @@ - name: x value: i8 return: i8 - + - name: "add_declared" + description: "Add two values." + impls: + - args: + - name: x + value: i8 + - name: y + value: i8 + nullability: DECLARED_OUTPUT + return: i8? + - name: "add_discrete" + description: "Add two values." + impls: + - args: + - name: x + value: i8? + - name: y + value: i8 + nullability: DISCRETE + return: i8? + - name: "test_decimal_discrete" + impls: + - args: + - name: x + value: decimal? + - name: y + value: decimal + nullability: DISCRETE + return: decimal? """ @@ -81,20 +109,46 @@ registry.register_extension_dict(yaml.safe_load(content), uri="test") -def i8(): - return Type(i8=Type.I8()) +def i8(nullable=False): + return Type( + i8=Type.I8( + nullability=Type.NULLABILITY_REQUIRED + if not nullable + else Type.NULLABILITY_NULLABLE + ) + ) -def i16(): - return Type(i16=Type.I16()) +def i16(nullable=False): + return Type( + i16=Type.I16( + nullability=Type.NULLABILITY_REQUIRED + if not nullable + else Type.NULLABILITY_NULLABLE + ) + ) -def bool(): - return Type(bool=Type.Boolean()) +def bool(nullable=False): + return Type( + bool=Type.Boolean( + nullability=Type.NULLABILITY_REQUIRED + if not nullable + else Type.NULLABILITY_NULLABLE + ) + ) -def decimal(precision, scale): - return Type(decimal=Type.Decimal(scale=scale, precision=precision)) +def decimal(precision, scale, nullable=False): + return Type( + decimal=Type.Decimal( + scale=scale, + precision=precision, + nullability=Type.NULLABILITY_REQUIRED + if not nullable + else Type.NULLABILITY_NULLABLE, + ) + ) def test_non_existing_uri(): @@ -125,13 +179,13 @@ def test_non_existing_function_signature(): def test_exact_match(): assert registry.lookup_function( uri="test", function_name="add", signature=[i8(), i8()] - )[1] == Type(i8=Type.I8()) + )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) def test_wildcard_match(): assert registry.lookup_function( uri="test", function_name="add", signature=[i8(), i8(), bool()] - )[1] == Type(i16=Type.I16()) + )[1] == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) def test_wildcard_match_fails_with_constraits(): @@ -198,6 +252,14 @@ def test_decimal_violates_constraint(): ) +def test_decimal_happy_path_discrete(): + assert registry.lookup_function( + uri="test", + function_name="test_decimal_discrete", + signature=[decimal(10, 8, nullable=True), decimal(8, 6)], + )[1] == decimal(11, 7, nullable=True) + + def test_enum_with_valid_option(): assert ( registry.lookup_function( @@ -218,3 +280,30 @@ def test_enum_with_nonexistent_option(): ) is None ) + + +def test_function_with_nullable_args(): + assert registry.lookup_function( + uri="test", function_name="add", signature=[i8(nullable=True), i8()] + )[1] == i8(nullable=True) + + +def test_function_with_declared_output_nullability(): + assert registry.lookup_function( + uri="test", function_name="add_declared", signature=[i8(), i8()] + )[1] == i8(nullable=True) + + +def test_function_with_discrete_nullability(): + assert registry.lookup_function( + uri="test", function_name="add_discrete", signature=[i8(nullable=True), i8()] + )[1] == i8(nullable=True) + + +def test_function_with_discrete_nullability(): + assert ( + registry.lookup_function( + uri="test", function_name="add_discrete", signature=[i8(), i8()] + ) + is None + ) From 00f563269a4fb36cfdd2d13c9cab8e25b270c793 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Fri, 14 Mar 2025 08:43:00 +0000 Subject: [PATCH 6/7] feat: add load_default_extensions arg to FunctionRegistry constructor --- src/substrait/function_registry.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/substrait/function_registry.py b/src/substrait/function_registry.py index 97d0ff2..f5f9722 100644 --- a/src/substrait/function_registry.py +++ b/src/substrait/function_registry.py @@ -317,18 +317,19 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: class FunctionRegistry: - def __init__(self) -> None: + def __init__(self, load_default_extensions=True) -> None: self._function_mapping: dict = defaultdict(dict) - self.id_generator = itertools.count(1) + self._id_generator = itertools.count(1) - self.uri_aliases = {} + self._uri_aliases = {} - for fpath in importlib_files("substrait.extensions").glob( # type: ignore - "functions*.yaml" - ): - uri = f"{DEFAULT_URI_PREFIX}/{fpath.name}" - self.uri_aliases[fpath.name] = uri - self.register_extension_yaml(fpath, uri) + if load_default_extensions: + for fpath in importlib_files("substrait.extensions").glob( # type: ignore + "functions*.yaml" + ): + uri = f"{DEFAULT_URI_PREFIX}/{fpath.name}" + self._uri_aliases[fpath.name] = uri + self.register_extension_yaml(fpath, uri) def register_extension_yaml( self, @@ -346,7 +347,7 @@ def register_extension_dict(self, definitions: dict, uri: str) -> None: for function in named_functions: for impl in function.get("impls", []): func = FunctionEntry( - uri, function["name"], impl, next(self.id_generator) + uri, function["name"], impl, next(self._id_generator) ) if ( func.uri in self._function_mapping @@ -360,7 +361,7 @@ def register_extension_dict(self, definitions: dict, uri: str) -> None: def lookup_function( self, uri: str, function_name: str, signature: tuple ) -> Optional[tuple[FunctionEntry, Type]]: - uri = self.uri_aliases.get(uri, uri) + uri = self._uri_aliases.get(uri, uri) if ( uri not in self._function_mapping From 739b7cf2b578415c0e698cea3bdb0afa862f3f1a Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Fri, 14 Mar 2025 17:44:37 +0000 Subject: [PATCH 7/7] feat: reimplement covers with antlr --- src/substrait/derivation_expression.py | 11 +- src/substrait/function_registry.py | 243 ++++++++----------------- tests/test_function_registry.py | 28 ++- 3 files changed, 111 insertions(+), 171 deletions(-) diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index 8b813a7..f5e68c1 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -37,7 +37,6 @@ def _evaluate(x, values: dict): elif type(x) == SubstraitTypeParser.FunctionCallContext: exprs = [_evaluate(e, values) for e in x.expr()] func = x.Identifier().symbol.text - if func == "min": return min(*exprs) elif func == "max": @@ -103,12 +102,18 @@ def _evaluate(x, values: dict): return _evaluate(x.finalType, values) elif type(x) == SubstraitTypeParser.TypeLiteralContext: return _evaluate(x.type_(), values) + elif type(x) == SubstraitTypeParser.NumericLiteralContext: + return int(str(x.Number())) else: raise Exception(f"Unknown token type {type(x)}") -def evaluate(x: str, values: Optional[dict] = None): +def _parse(x: str): lexer = SubstraitTypeLexer(InputStream(x)) stream = CommonTokenStream(lexer) parser = SubstraitTypeParser(stream) - return _evaluate(parser.expr(), values) + return parser.expr() + + +def evaluate(x: str, values: Optional[dict] = None): + return _evaluate(_parse(x), values) diff --git a/src/substrait/function_registry.py b/src/substrait/function_registry.py index f5f9722..101f2d7 100644 --- a/src/substrait/function_registry.py +++ b/src/substrait/function_registry.py @@ -6,10 +6,9 @@ from collections.abc import Mapping from pathlib import Path from typing import Any, Optional, Union -from .derivation_expression import evaluate +from .derivation_expression import evaluate, _evaluate, _parse import yaml -import re DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" @@ -57,193 +56,104 @@ def normalize_substrait_type_names(typ: str) -> str: return "any" elif typ.startswith("u!"): return typ - else: + elif typ in _normalized_key_names: return _normalized_key_names[typ] - - -def to_integer_option(txt: str): - if txt.isnumeric(): - return ParameterizedType.IntegerOption(literal=int(txt)) else: - return ParameterizedType.IntegerOption( - parameter=ParameterizedType.IntegerParameter(name=txt) - ) + raise Exception(f"Unrecognized substrait type {typ}") -# TODO try using antlr grammar here as well -def to_parameterized_type(dtype: str): - if "?" in dtype: - dtype = dtype.replace("?", "") - nullability = Type.NULLABILITY_NULLABLE +def violates_integer_option(actual: int, option, parameters: dict): + if isinstance(option, SubstraitTypeParser.NumericLiteralContext): + return actual != int(str(option.Number())) + elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext): + parameter_name = str(option.Identifier()) + if parameter_name in parameters and parameters[parameter_name] != actual: + return True + else: + parameters[parameter_name] = actual else: - nullability = Type.NULLABILITY_REQUIRED - - if dtype == "boolean": - return ParameterizedType(bool=Type.Boolean(nullability=nullability)) - elif dtype == "i8": - return ParameterizedType(i8=Type.I8(nullability=nullability)) - elif dtype == "i16": - return ParameterizedType(i16=Type.I16(nullability=nullability)) - elif dtype == "i32": - return ParameterizedType(i32=Type.I32(nullability=nullability)) - elif dtype == "i64": - return ParameterizedType(i64=Type.I64(nullability=nullability)) - elif dtype == "fp32": - return ParameterizedType(fp32=Type.FP32(nullability=nullability)) - elif dtype == "fp64": - return ParameterizedType(fp64=Type.FP64(nullability=nullability)) - elif dtype == "timestamp": - return ParameterizedType(timestamp=Type.Timestamp(nullability=nullability)) - elif dtype == "timestamp_tz": - return ParameterizedType(timestamp_tz=Type.TimestampTZ(nullability=nullability)) - elif dtype == "date": - return ParameterizedType(date=Type.Date(nullability=nullability)) - elif dtype == "time": - return ParameterizedType(time=Type.Time(nullability=nullability)) - elif dtype == "interval_year": - return ParameterizedType( - interval_year=Type.IntervalYear(nullability=nullability) - ) - elif dtype.startswith("decimal") or dtype.startswith("DECIMAL"): - (_, precision, scale, _) = re.split(r"\W+", dtype) - - return ParameterizedType( - decimal=ParameterizedType.ParameterizedDecimal( - scale=to_integer_option(scale), - precision=to_integer_option(precision), - nullability=nullability, - ) + raise Exception( + f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead" ) - elif dtype.startswith("varchar"): - (_, length, _) = re.split(r"\W+", dtype) - return ParameterizedType( - varchar=ParameterizedType.ParameterizedVarChar( - length=to_integer_option(length), nullability=nullability - ) - ) - elif dtype.startswith("precision_timestamp"): - (_, precision, _) = re.split(r"\W+", dtype) + return False - return ParameterizedType( - precision_timestamp=ParameterizedType.ParameterizedPrecisionTimestamp( - precision=to_integer_option(precision), nullability=nullability - ) - ) - elif dtype.startswith("precision_timestamp_tz"): - (_, precision, _) = re.split(r"\W+", dtype) - return ParameterizedType( - precision_timestamp_tz=ParameterizedType.ParameterizedPrecisionTimestampTZ( - precision=to_integer_option(precision), nullability=nullability - ) - ) - elif dtype.startswith("fixedchar"): - (_, length, _) = re.split(r"\W+", dtype) +from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser - return ParameterizedType( - fixed_char=ParameterizedType.ParameterizedFixedChar( - length=to_integer_option(length), nullability=nullability - ) - ) - elif dtype == "string": - return ParameterizedType(string=Type.String(nullability=nullability)) - elif dtype.startswith("list"): - inner_dtype = dtype[5:-1] - return ParameterizedType( - list=ParameterizedType.ParameterizedList( - type=to_parameterized_type(inner_dtype), nullability=nullability - ) - ) - elif dtype.startswith("interval_day"): - (_, precision, _) = re.split(r"\W+", dtype) - return ParameterizedType( - interval_day=ParameterizedType.ParameterizedIntervalDay( - precision=to_integer_option(precision), nullability=nullability - ) - ) - elif dtype.startswith("any"): - return ParameterizedType( - type_parameter=ParameterizedType.TypeParameter(name=dtype) - ) - elif dtype.startswith("u!") or dtype == "geometry": - return ParameterizedType( - user_defined=ParameterizedType.ParameterizedUserDefined( - nullability=nullability - ) - ) +def types_equal(type1: Type, type2: Type, check_nullability=False): + if check_nullability: + return type1 == type2 else: - raise Exception(f"Unknown type - {dtype}") + x, y = Type(), Type() + x.CopyFrom(type1) + y.CopyFrom(type2) + x.__getattribute__( + x.WhichOneof("kind") + ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED + y.__getattribute__( + y.WhichOneof("kind") + ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED + return x == y -def violates_integer_option( - actual: int, option: ParameterizedType.IntegerOption, parameters: dict +def covers( + covered: Type, + covering: SubstraitTypeParser.TypeLiteralContext, + parameters: dict, + check_nullability=False, ): - integer_type = option.WhichOneof("integer_type") + if isinstance(covering, SubstraitTypeParser.TypeParamContext): + parameter_name = str(covering.Identifier()) - if integer_type == "literal" and actual != option.literal: - return True - else: - parameter_name = option.parameter.name - if parameter_name in parameters and parameters[parameter_name] != actual: - return True + if parameter_name in parameters: + covering = parameters[parameter_name] + + return types_equal(covering, covered, check_nullability) else: - parameters[parameter_name] = actual + parameters[parameter_name] = covered + return True - return False + covering = covering.type_() + scalar_type = covering.scalarType() + if scalar_type: + covering = _evaluate(covering, {}) + return types_equal(covering, covered, check_nullability) + parameterized_type = covering.parameterizedType() + if parameterized_type: + if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): + if covered.WhichOneof("kind") != "decimal": + return False -def covers( - dtype: Type, - parameterized_type: ParameterizedType, - parameters: dict, - check_nullability=False, -): - expected_kind = parameterized_type.WhichOneof("kind") + nullability = ( + Type.NULLABILITY_NULLABLE + if parameterized_type.isnull + else Type.NULLABILITY_REQUIRED + ) - if expected_kind == "type_parameter": - parameter_name = parameterized_type.type_parameter.name - # TODO figure out how to do nullability checks with "any" types - if parameter_name == "any": - return True - else: - if parameter_name in parameters and parameters[ - parameter_name - ].SerializeToString(deterministic=True) != dtype.SerializeToString( - deterministic=True + if ( + check_nullability + and nullability + != covered.__getattribute__(covered.WhichOneof("kind")).nullability ): return False - else: - parameters[parameter_name] = dtype - return True - - expected_nullability = parameterized_type.__getattribute__( - parameterized_type.WhichOneof("kind") - ).nullability - - kind = dtype.WhichOneof("kind") - - if kind != expected_kind: - return False - - if ( - check_nullability - and dtype.__getattribute__(dtype.WhichOneof("kind")).nullability - != expected_nullability - ): - return False - - if kind == "decimal": - if violates_integer_option( - dtype.decimal.scale, parameterized_type.decimal.scale, parameters - ) or violates_integer_option( - dtype.decimal.precision, parameterized_type.decimal.precision, parameters - ): - return False - # TODO handle all types - return True + return not ( + violates_integer_option( + covered.decimal.scale, parameterized_type.scale, parameters + ) + or violates_integer_option( + covered.decimal.precision, parameterized_type.precision, parameters + ) + ) + else: + raise Exception(f"Unhandled type {type(parameterized_type)}") + + any_type = covering.anyType() + if any_type: + return True class FunctionEntry: @@ -261,7 +171,7 @@ def __init__( if input_args := impl.get("args", []): for val in input_args: if typ := val.get("value"): - self.arguments.append(to_parameterized_type(typ)) + self.arguments.append(_parse(typ)) self.normalized_inputs.append(normalize_substrait_type_names(typ)) elif arg_name := val.get("name", None): self.arguments.append(val.get("options")) @@ -296,7 +206,6 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: return None output_type = evaluate(self.rtn, parameters) - print(output_type) if self.nullability == "MIRROR": sig_contains_nullable = any( diff --git a/tests/test_function_registry.py b/tests/test_function_registry.py index 85b0526..ef7387e 100644 --- a/tests/test_function_registry.py +++ b/tests/test_function_registry.py @@ -1,7 +1,8 @@ import yaml from substrait.gen.proto.type_pb2 import Type -from substrait.function_registry import FunctionRegistry +from substrait.function_registry import FunctionRegistry, covers +from substrait.derivation_expression import _parse content = """%YAML 1.2 --- @@ -307,3 +308,28 @@ def test_function_with_discrete_nullability(): ) is None ) + + +def test_covers(): + params = {} + assert covers(i8(), _parse("i8"), params) + assert params == {} + + +def test_covers_nullability(): + assert not covers(i8(nullable=True), _parse("i8"), {}, check_nullability=True) + assert covers(i8(nullable=True), _parse("i8?"), {}, check_nullability=True) + + +def test_covers_decimal(): + assert not covers(decimal(10, 8), _parse("decimal<11, A>"), {}) + + +def test_covers_decimal_happy_path(): + params = {} + assert covers(decimal(10, 8), _parse("decimal<10, A>"), params) + assert params == {"A": 8} + + +def test_covers_any(): + assert covers(decimal(10, 8), _parse("any"), {})