Skip to content

Commit 2df0cd0

Browse files
authored
feat: add type inference (#67)
Adds helper functions that infer expression types and rel schemas from Substrait. The implementation is not yet meant to be exhaustive; it covers only a subset of types/expressions/rels.
1 parent 95f7b7d commit 2df0cd0

File tree

3 files changed

+963
-0
lines changed

3 files changed

+963
-0
lines changed

src/substrait/type_inference.py

+336
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
import substrait.gen.proto.algebra_pb2 as stalg
2+
import substrait.gen.proto.type_pb2 as stt
3+
4+
5+
def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type:
    """Derive the Substrait type of a literal expression.

    The populated member of the ``literal_type`` oneof selects the resulting
    type, and the literal's ``nullable`` flag selects its nullability. The
    ``null``, ``empty_list`` and ``empty_map`` literals already carry a type
    message, which is reused as-is.

    Raises:
        Exception: if the populated ``literal_type`` member is not handled
            here (this includes the case where no member is set).
    """
    kind = literal.WhichOneof("literal_type")

    if literal.nullable:
        nullability = stt.Type.Nullability.NULLABILITY_NULLABLE
    else:
        nullability = stt.Type.Nullability.NULLABILITY_REQUIRED

    # Literal kinds whose type carries nothing besides a nullability:
    # literal oneof member -> (Type oneof member, name of the Type message).
    # Names are resolved lazily so only the matching message is looked up.
    plain_kinds = {
        "boolean": ("bool", "Boolean"),
        "i8": ("i8", "I8"),
        "i16": ("i16", "I16"),
        "i32": ("i32", "I32"),
        "i64": ("i64", "I64"),
        "fp32": ("fp32", "FP32"),
        "fp64": ("fp64", "FP64"),
        "string": ("string", "String"),
        "binary": ("binary", "Binary"),
        "timestamp": ("timestamp", "Timestamp"),
        "date": ("date", "Date"),
        "time": ("time", "Time"),
        "interval_year_to_month": ("interval_year", "IntervalYear"),
        "timestamp_tz": ("timestamp_tz", "TimestampTZ"),
        "uuid": ("uuid", "UUID"),
    }
    if kind in plain_kinds:
        type_field, message_name = plain_kinds[kind]
        message = getattr(stt.Type, message_name)(nullability=nullability)
        return stt.Type(**{type_field: message})

    # Literals that already embed their own type message.
    if kind == "null":
        return literal.null
    if kind == "empty_list":
        return stt.Type(list=literal.empty_list)
    if kind == "empty_map":
        return stt.Type(map=literal.empty_map)

    # Interval kinds parameterised by a precision.
    if kind == "interval_day_to_second":
        day_type = stt.Type.IntervalDay(
            precision=literal.interval_day_to_second.precision,
            nullability=nullability,
        )
        return stt.Type(interval_day=day_type)
    if kind == "interval_compound":
        compound_type = stt.Type.IntervalCompound(
            nullability=nullability,
            precision=literal.interval_compound.interval_day_to_second.precision,
        )
        return stt.Type(interval_compound=compound_type)

    # Length-parameterised kinds: fixed-width values take their length from
    # the payload itself, var_char from its declared length.
    if kind == "fixed_char":
        char_type = stt.Type.FixedChar(
            length=len(literal.fixed_char), nullability=nullability
        )
        return stt.Type(fixed_char=char_type)
    if kind == "var_char":
        varchar_type = stt.Type.VarChar(
            length=literal.var_char.length, nullability=nullability
        )
        return stt.Type(varchar=varchar_type)
    if kind == "fixed_binary":
        binary_type = stt.Type.FixedBinary(
            length=len(literal.fixed_binary), nullability=nullability
        )
        return stt.Type(fixed_binary=binary_type)

    if kind == "decimal":
        decimal_type = stt.Type.Decimal(
            scale=literal.decimal.scale,
            precision=literal.decimal.precision,
            nullability=nullability,
        )
        return stt.Type(decimal=decimal_type)

    if kind == "precision_timestamp":
        ts_type = stt.Type.PrecisionTimestamp(
            precision=literal.precision_timestamp.precision, nullability=nullability
        )
        return stt.Type(precision_timestamp=ts_type)
    if kind == "precision_timestamp_tz":
        ts_tz_type = stt.Type.PrecisionTimestampTZ(
            precision=literal.precision_timestamp_tz.precision,
            nullability=nullability,
        )
        return stt.Type(precision_timestamp_tz=ts_tz_type)

    # Compound kinds recurse into their members.
    if kind == "struct":
        field_types = list(map(infer_literal_type, literal.struct.fields))
        return stt.Type(
            struct=stt.Type.Struct(types=field_types, nullability=nullability)
        )
    if kind == "map":
        # The first entry determines the key and value types.
        first_entry = literal.map.key_values[0]
        map_type = stt.Type.Map(
            key=infer_literal_type(first_entry.key),
            value=infer_literal_type(first_entry.value),
            nullability=nullability,
        )
        return stt.Type(map=map_type)
    if kind == "list":
        # The first element determines the element type.
        element_type = infer_literal_type(literal.list.values[0])
        return stt.Type(
            list=stt.Type.List(type=element_type, nullability=nullability)
        )

    raise Exception(f"Unknown literal_type {kind}")
126+
127+
128+
def infer_nested_type(
    nested: stalg.Expression.Nested,
    parent_schema: "stt.Type.Struct | None" = None,
) -> stt.Type:
    """Derive the Substrait type of a nested (struct/list/map) expression.

    Args:
        nested: The nested expression to inspect.
        parent_schema: Schema that field references inside the nested
            expression are resolved against. It is forwarded unchanged to
            ``infer_expression_type``. It defaults to ``None`` so existing
            single-argument callers keep working; in that case member
            expressions that are field references cannot be resolved.

    Returns:
        The struct, list or map type. List and map element types are taken
        from the first element / first key-value pair.

    Raises:
        Exception: if the populated ``nested_type`` member is not handled.
    """
    nested_type = nested.WhichOneof("nested_type")

    nullability = (
        stt.Type.Nullability.NULLABILITY_NULLABLE
        if nested.nullable
        else stt.Type.Nullability.NULLABILITY_REQUIRED
    )

    if nested_type == "struct":
        return stt.Type(
            struct=stt.Type.Struct(
                types=[
                    infer_expression_type(f, parent_schema)
                    for f in nested.struct.fields
                ],
                nullability=nullability,
            )
        )
    elif nested_type == "list":
        return stt.Type(
            list=stt.Type.List(
                type=infer_expression_type(nested.list.values[0], parent_schema),
                nullability=nullability,
            )
        )
    elif nested_type == "map":
        first_entry = nested.map.key_values[0]
        return stt.Type(
            map=stt.Type.Map(
                key=infer_expression_type(first_entry.key, parent_schema),
                value=infer_expression_type(first_entry.value, parent_schema),
                nullability=nullability,
            )
        )
    else:
        raise Exception(f"Unknown nested_type {nested_type}")


def infer_expression_type(
    expression: stalg.Expression, parent_schema: stt.Type.Struct
) -> stt.Type:
    """Derive the Substrait type produced by an expression.

    Args:
        expression: The expression to inspect.
        parent_schema: Schema of the input the expression is evaluated
            against; direct struct-field references index into its ``types``.

    Returns:
        The inferred type. For ``if_then`` and ``switch_expression`` the type
        of the first branch's ``then`` expression is used. For or-list
        expressions and predicate subqueries a nullable boolean is returned.

    Raises:
        Exception: if the expression kind, reference kind or subquery kind
            is not handled.
    """
    rex_type = expression.WhichOneof("rex_type")
    if rex_type == "selection":
        root_type = expression.selection.WhichOneof("root_type")
        # Only references rooted at the input record are supported.
        assert root_type == "root_reference", (
            f"Unsupported root_type {root_type}"
        )

        reference_type = expression.selection.WhichOneof("reference_type")

        if reference_type == "direct_reference":
            segment = expression.selection.direct_reference

            segment_reference_type = segment.WhichOneof("reference_type")

            if segment_reference_type == "struct_field":
                return parent_schema.types[segment.struct_field.field]
            else:
                # Report the segment's kind, not the enclosing selection's.
                raise Exception(f"Unknown reference_type {segment_reference_type}")
        else:
            raise Exception(f"Unknown reference_type {reference_type}")

    elif rex_type == "literal":
        return infer_literal_type(expression.literal)
    elif rex_type == "scalar_function":
        return expression.scalar_function.output_type
    elif rex_type == "window_function":
        return expression.window_function.output_type
    elif rex_type == "if_then":
        # Branch expressions are evaluated against the same input schema.
        return infer_expression_type(expression.if_then.ifs[0].then, parent_schema)
    elif rex_type == "switch_expression":
        return infer_expression_type(
            expression.switch_expression.ifs[0].then, parent_schema
        )
    elif rex_type == "cast":
        return expression.cast.type
    elif rex_type == "singular_or_list" or rex_type == "multi_or_list":
        return stt.Type(
            bool=stt.Type.Boolean(nullability=stt.Type.Nullability.NULLABILITY_NULLABLE)
        )
    elif rex_type == "nested":
        return infer_nested_type(expression.nested, parent_schema)
    elif rex_type == "subquery":
        subquery_type = expression.subquery.WhichOneof("subquery_type")

        if subquery_type == "scalar":
            scalar_rel = infer_rel_schema(expression.subquery.scalar.input)
            return scalar_rel.types[0]
        elif (
            subquery_type == "in_predicate"
            or subquery_type == "set_comparison"
            or subquery_type == "set_predicate"
        ):
            # NOTE(review): the original author questioned whether these
            # predicates can actually yield null; nullable is kept for now.
            return stt.Type(
                bool=stt.Type.Boolean(
                    nullability=stt.Type.Nullability.NULLABILITY_NULLABLE
                )
            )
        else:
            raise Exception(f"Unknown subquery_type {subquery_type}")
    else:
        raise Exception(f"Unknown rex_type {rex_type}")
221+
222+
223+
def _infer_join_schema(join: stalg.JoinRel) -> stt.Type.Struct:
    """Build the pre-emit output struct of a join from its join type."""
    join_type = join.type

    if join_type in [
        stalg.JoinRel.JOIN_TYPE_INNER,
        stalg.JoinRel.JOIN_TYPE_OUTER,
        stalg.JoinRel.JOIN_TYPE_LEFT,
        stalg.JoinRel.JOIN_TYPE_RIGHT,
        stalg.JoinRel.JOIN_TYPE_LEFT_SINGLE,
        stalg.JoinRel.JOIN_TYPE_RIGHT_SINGLE,
    ]:
        # Input field types are concatenated unchanged.
        # NOTE(review): outer-style joins may need the non-preserved side
        # marked nullable -- confirm against the Substrait spec.
        types = list(infer_rel_schema(join.left).types) + list(
            infer_rel_schema(join.right).types
        )
    elif join_type in [
        stalg.JoinRel.JOIN_TYPE_LEFT_ANTI,
        stalg.JoinRel.JOIN_TYPE_LEFT_SEMI,
    ]:
        types = list(infer_rel_schema(join.left).types)
    elif join_type in [
        stalg.JoinRel.JOIN_TYPE_RIGHT_ANTI,
        stalg.JoinRel.JOIN_TYPE_RIGHT_SEMI,
    ]:
        types = list(infer_rel_schema(join.right).types)
    elif join_type in [
        stalg.JoinRel.JOIN_TYPE_LEFT_MARK,
        stalg.JoinRel.JOIN_TYPE_RIGHT_MARK,
    ]:
        # Both inputs followed by a nullable boolean "mark" column.
        # NOTE(review): verify whether mark joins should emit only one
        # side plus the mark column.
        types = (
            list(infer_rel_schema(join.left).types)
            + list(infer_rel_schema(join.right).types)
            + [
                stt.Type(
                    bool=stt.Type.Boolean(
                        nullability=stt.Type.Nullability.NULLABILITY_NULLABLE
                    )
                )
            ]
        )
    else:
        raise Exception(f"Unhandled join_type {join.type}")

    return stt.Type.Struct(
        types=types,
        nullability=stt.Type.Nullability.NULLABILITY_REQUIRED,
    )


def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
    """Derive the output schema (as a struct type) of a relation.

    Handles read, filter, fetch, aggregate, sort, project, set, cross and
    join relations. The schema is first computed without regard to emit and
    is then remapped through ``common.emit.output_mapping`` when the
    relation's ``emit_kind`` is not direct.

    Args:
        rel: The relation to inspect.

    Returns:
        The struct type describing the relation's output columns.

    Raises:
        Exception: if the relation kind or join type is not handled.
    """
    rel_type = rel.WhichOneof("rel_type")

    if rel_type == "read":
        (common, struct) = (rel.read.common, rel.read.base_schema.struct)
    elif rel_type == "filter":
        (common, struct) = (rel.filter.common, infer_rel_schema(rel.filter.input))
    elif rel_type == "fetch":
        (common, struct) = (rel.fetch.common, infer_rel_schema(rel.fetch.input))
    elif rel_type == "aggregate":
        parent_schema = infer_rel_schema(rel.aggregate.input)
        grouping_types = [
            infer_expression_type(g, parent_schema)
            for g in rel.aggregate.grouping_expressions
        ]
        measure_types = [m.measure.output_type for m in rel.aggregate.measures]

        # An extra required i32 column is appended only when there is more
        # than one grouping set.
        grouping_identifier_types = (
            []
            if len(rel.aggregate.groupings) <= 1
            else [
                stt.Type(
                    i32=stt.Type.I32(
                        nullability=stt.Type.Nullability.NULLABILITY_REQUIRED
                    )
                )
            ]
        )

        raw_schema = stt.Type.Struct(
            types=grouping_types + measure_types + grouping_identifier_types,
            nullability=parent_schema.nullability,
        )

        (common, struct) = (rel.aggregate.common, raw_schema)
    elif rel_type == "sort":
        (common, struct) = (rel.sort.common, infer_rel_schema(rel.sort.input))
    elif rel_type == "project":
        parent_schema = infer_rel_schema(rel.project.input)
        expression_types = [
            infer_expression_type(e, parent_schema) for e in rel.project.expressions
        ]
        # Input columns come first, followed by the projected expressions.
        raw_schema = stt.Type.Struct(
            types=list(parent_schema.types) + expression_types,
            nullability=parent_schema.nullability,
        )

        (common, struct) = (rel.project.common, raw_schema)
    elif rel_type == "set":
        # Use the set relation's own common block so that its emit mapping
        # is honoured; the schema is taken from the first input.
        (common, struct) = (rel.set.common, infer_rel_schema(rel.set.inputs[0]))
    elif rel_type == "cross":
        left_schema = infer_rel_schema(rel.cross.left)
        right_schema = infer_rel_schema(rel.cross.right)

        raw_schema = stt.Type.Struct(
            types=list(left_schema.types) + list(right_schema.types),
            nullability=stt.Type.Nullability.NULLABILITY_REQUIRED,
        )

        (common, struct) = (rel.cross.common, raw_schema)
    elif rel_type == "join":
        (common, struct) = (rel.join.common, _infer_join_schema(rel.join))
    else:
        raise Exception(f"Unhandled rel_type {rel_type}")

    # An unset emit_kind is treated the same as an explicit direct emit.
    emit_kind = common.WhichOneof("emit_kind") or "direct"

    if emit_kind == "direct":
        return struct

    return stt.Type.Struct(
        types=[struct.types[i] for i in common.emit.output_mapping],
        nullability=struct.nullability,
    )

0 commit comments

Comments
 (0)