|
| 1 | +from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType |
| 2 | +from substrait.gen.proto.type_pb2 import Type |
| 3 | +from importlib.resources import files as importlib_files |
| 4 | +import itertools |
| 5 | +from collections import defaultdict |
| 6 | +from collections.abc import Mapping |
| 7 | +from pathlib import Path |
| 8 | +from typing import Any, Optional, Union |
| 9 | +from .derivation_expression import evaluate, _evaluate, _parse |
| 10 | + |
| 11 | +import yaml |
| 12 | + |
| 13 | + |
| 14 | +DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" |
| 15 | + |
| 16 | + |
| 17 | +# mapping from argument types to shortened signature names: https://substrait.io/extensions/#function-signature-compound-names |
| 18 | +_normalized_key_names = { |
| 19 | + "i8": "i8", |
| 20 | + "i16": "i16", |
| 21 | + "i32": "i32", |
| 22 | + "i64": "i64", |
| 23 | + "fp32": "fp32", |
| 24 | + "fp64": "fp64", |
| 25 | + "string": "str", |
| 26 | + "binary": "vbin", |
| 27 | + "boolean": "bool", |
| 28 | + "timestamp": "ts", |
| 29 | + "timestamp_tz": "tstz", |
| 30 | + "date": "date", |
| 31 | + "time": "time", |
| 32 | + "interval_year": "iyear", |
| 33 | + "interval_day": "iday", |
| 34 | + "interval_compound": "icompound", |
| 35 | + "uuid": "uuid", |
| 36 | + "fixedchar": "fchar", |
| 37 | + "varchar": "vchar", |
| 38 | + "fixedbinary": "fbin", |
| 39 | + "decimal": "dec", |
| 40 | + "precision_time": "pt", |
| 41 | + "precision_timestamp": "pts", |
| 42 | + "precision_timestamp_tz": "ptstz", |
| 43 | + "struct": "struct", |
| 44 | + "list": "list", |
| 45 | + "map": "map", |
| 46 | +} |
| 47 | + |
| 48 | + |
| 49 | +def normalize_substrait_type_names(typ: str) -> str: |
| 50 | + # Strip type specifiers |
| 51 | + typ = typ.split("<")[0] |
| 52 | + # First strip nullability marker |
| 53 | + typ = typ.strip("?").lower() |
| 54 | + |
| 55 | + if typ.startswith("any"): |
| 56 | + return "any" |
| 57 | + elif typ.startswith("u!"): |
| 58 | + return typ |
| 59 | + elif typ in _normalized_key_names: |
| 60 | + return _normalized_key_names[typ] |
| 61 | + else: |
| 62 | + raise Exception(f"Unrecognized substrait type {typ}") |
| 63 | + |
| 64 | + |
| 65 | +def violates_integer_option(actual: int, option, parameters: dict): |
| 66 | + if isinstance(option, SubstraitTypeParser.NumericLiteralContext): |
| 67 | + return actual != int(str(option.Number())) |
| 68 | + elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext): |
| 69 | + parameter_name = str(option.Identifier()) |
| 70 | + if parameter_name in parameters and parameters[parameter_name] != actual: |
| 71 | + return True |
| 72 | + else: |
| 73 | + parameters[parameter_name] = actual |
| 74 | + else: |
| 75 | + raise Exception( |
| 76 | + f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead" |
| 77 | + ) |
| 78 | + |
| 79 | + return False |
| 80 | + |
| 81 | + |
| 82 | +from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser |
| 83 | + |
| 84 | + |
| 85 | +def types_equal(type1: Type, type2: Type, check_nullability=False): |
| 86 | + if check_nullability: |
| 87 | + return type1 == type2 |
| 88 | + else: |
| 89 | + x, y = Type(), Type() |
| 90 | + x.CopyFrom(type1) |
| 91 | + y.CopyFrom(type2) |
| 92 | + x.__getattribute__( |
| 93 | + x.WhichOneof("kind") |
| 94 | + ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED |
| 95 | + y.__getattribute__( |
| 96 | + y.WhichOneof("kind") |
| 97 | + ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED |
| 98 | + return x == y |
| 99 | + |
| 100 | + |
| 101 | +def covers( |
| 102 | + covered: Type, |
| 103 | + covering: SubstraitTypeParser.TypeLiteralContext, |
| 104 | + parameters: dict, |
| 105 | + check_nullability=False, |
| 106 | +): |
| 107 | + if isinstance(covering, SubstraitTypeParser.TypeParamContext): |
| 108 | + parameter_name = str(covering.Identifier()) |
| 109 | + |
| 110 | + if parameter_name in parameters: |
| 111 | + covering = parameters[parameter_name] |
| 112 | + |
| 113 | + return types_equal(covering, covered, check_nullability) |
| 114 | + else: |
| 115 | + parameters[parameter_name] = covered |
| 116 | + return True |
| 117 | + |
| 118 | + covering = covering.type_() |
| 119 | + scalar_type = covering.scalarType() |
| 120 | + if scalar_type: |
| 121 | + covering = _evaluate(covering, {}) |
| 122 | + return types_equal(covering, covered, check_nullability) |
| 123 | + |
| 124 | + parameterized_type = covering.parameterizedType() |
| 125 | + if parameterized_type: |
| 126 | + if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): |
| 127 | + if covered.WhichOneof("kind") != "decimal": |
| 128 | + return False |
| 129 | + |
| 130 | + nullability = ( |
| 131 | + Type.NULLABILITY_NULLABLE |
| 132 | + if parameterized_type.isnull |
| 133 | + else Type.NULLABILITY_REQUIRED |
| 134 | + ) |
| 135 | + |
| 136 | + if ( |
| 137 | + check_nullability |
| 138 | + and nullability |
| 139 | + != covered.__getattribute__(covered.WhichOneof("kind")).nullability |
| 140 | + ): |
| 141 | + return False |
| 142 | + |
| 143 | + return not ( |
| 144 | + violates_integer_option( |
| 145 | + covered.decimal.scale, parameterized_type.scale, parameters |
| 146 | + ) |
| 147 | + or violates_integer_option( |
| 148 | + covered.decimal.precision, parameterized_type.precision, parameters |
| 149 | + ) |
| 150 | + ) |
| 151 | + else: |
| 152 | + raise Exception(f"Unhandled type {type(parameterized_type)}") |
| 153 | + |
| 154 | + any_type = covering.anyType() |
| 155 | + if any_type: |
| 156 | + return True |
| 157 | + |
| 158 | + |
| 159 | +class FunctionEntry: |
| 160 | + def __init__( |
| 161 | + self, uri: str, name: str, impl: Mapping[str, Any], anchor: int |
| 162 | + ) -> None: |
| 163 | + self.name = name |
| 164 | + self.normalized_inputs: list = [] |
| 165 | + self.uri: str = uri |
| 166 | + self.anchor = anchor |
| 167 | + self.arguments = [] |
| 168 | + self.rtn = impl["return"] |
| 169 | + self.nullability = impl.get("nullability", "MIRROR") |
| 170 | + self.variadic = impl.get("variadic", False) |
| 171 | + if input_args := impl.get("args", []): |
| 172 | + for val in input_args: |
| 173 | + if typ := val.get("value"): |
| 174 | + self.arguments.append(_parse(typ)) |
| 175 | + self.normalized_inputs.append(normalize_substrait_type_names(typ)) |
| 176 | + elif arg_name := val.get("name", None): |
| 177 | + self.arguments.append(val.get("options")) |
| 178 | + self.normalized_inputs.append("req") |
| 179 | + |
| 180 | + def __repr__(self) -> str: |
| 181 | + return f"{self.name}:{'_'.join(self.normalized_inputs)}" |
| 182 | + |
| 183 | + def satisfies_signature(self, signature: tuple) -> Optional[str]: |
| 184 | + if self.variadic: |
| 185 | + min_args_allowed = self.variadic.get("min", 0) |
| 186 | + if len(signature) < min_args_allowed: |
| 187 | + return None |
| 188 | + inputs = [self.arguments[0]] * len(signature) |
| 189 | + else: |
| 190 | + inputs = self.arguments |
| 191 | + if len(inputs) != len(signature): |
| 192 | + return None |
| 193 | + |
| 194 | + zipped_args = list(zip(inputs, signature)) |
| 195 | + |
| 196 | + parameters = {} |
| 197 | + |
| 198 | + for x, y in zipped_args: |
| 199 | + if type(y) == str: |
| 200 | + if y not in x: |
| 201 | + return None |
| 202 | + else: |
| 203 | + if not covers( |
| 204 | + y, x, parameters, check_nullability=self.nullability == "DISCRETE" |
| 205 | + ): |
| 206 | + return None |
| 207 | + |
| 208 | + output_type = evaluate(self.rtn, parameters) |
| 209 | + |
| 210 | + if self.nullability == "MIRROR": |
| 211 | + sig_contains_nullable = any( |
| 212 | + [ |
| 213 | + p.__getattribute__(p.WhichOneof("kind")).nullability |
| 214 | + == Type.NULLABILITY_NULLABLE |
| 215 | + for p in signature |
| 216 | + if type(p) == Type |
| 217 | + ] |
| 218 | + ) |
| 219 | + output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( |
| 220 | + Type.NULLABILITY_NULLABLE |
| 221 | + if sig_contains_nullable |
| 222 | + else Type.NULLABILITY_REQUIRED |
| 223 | + ) |
| 224 | + |
| 225 | + return output_type |
| 226 | + |
| 227 | + |
| 228 | +class FunctionRegistry: |
| 229 | + def __init__(self, load_default_extensions=True) -> None: |
| 230 | + self._function_mapping: dict = defaultdict(dict) |
| 231 | + self._id_generator = itertools.count(1) |
| 232 | + |
| 233 | + self._uri_aliases = {} |
| 234 | + |
| 235 | + if load_default_extensions: |
| 236 | + for fpath in importlib_files("substrait.extensions").glob( # type: ignore |
| 237 | + "functions*.yaml" |
| 238 | + ): |
| 239 | + uri = f"{DEFAULT_URI_PREFIX}/{fpath.name}" |
| 240 | + self._uri_aliases[fpath.name] = uri |
| 241 | + self.register_extension_yaml(fpath, uri) |
| 242 | + |
| 243 | + def register_extension_yaml( |
| 244 | + self, |
| 245 | + fname: Union[str, Path], |
| 246 | + uri: str, |
| 247 | + ) -> None: |
| 248 | + fname = Path(fname) |
| 249 | + with open(fname) as f: # type: ignore |
| 250 | + extension_definitions = yaml.safe_load(f) |
| 251 | + |
| 252 | + self.register_extension_dict(extension_definitions, uri) |
| 253 | + |
| 254 | + def register_extension_dict(self, definitions: dict, uri: str) -> None: |
| 255 | + for named_functions in definitions.values(): |
| 256 | + for function in named_functions: |
| 257 | + for impl in function.get("impls", []): |
| 258 | + func = FunctionEntry( |
| 259 | + uri, function["name"], impl, next(self._id_generator) |
| 260 | + ) |
| 261 | + if ( |
| 262 | + func.uri in self._function_mapping |
| 263 | + and function["name"] in self._function_mapping[func.uri] |
| 264 | + ): |
| 265 | + self._function_mapping[func.uri][function["name"]].append(func) |
| 266 | + else: |
| 267 | + self._function_mapping[func.uri][function["name"]] = [func] |
| 268 | + |
| 269 | + # TODO add an optional return type check |
| 270 | + def lookup_function( |
| 271 | + self, uri: str, function_name: str, signature: tuple |
| 272 | + ) -> Optional[tuple[FunctionEntry, Type]]: |
| 273 | + uri = self._uri_aliases.get(uri, uri) |
| 274 | + |
| 275 | + if ( |
| 276 | + uri not in self._function_mapping |
| 277 | + or function_name not in self._function_mapping[uri] |
| 278 | + ): |
| 279 | + return None |
| 280 | + functions = self._function_mapping[uri][function_name] |
| 281 | + for f in functions: |
| 282 | + assert isinstance(f, FunctionEntry) |
| 283 | + rtn = f.satisfies_signature(signature) |
| 284 | + if rtn is not None: |
| 285 | + return (f, rtn) |
| 286 | + |
| 287 | + return None |
0 commit comments