Skip to content

Commit 444da04

Browse files
committed
fix: add missing output names
1 parent 0ab87e9 commit 444da04

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

src/substrait/extended_expression.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ def resolve(base_schema: stp.NamedStruct) -> stee.ExtendedExpression:
1010
column_index = list(base_schema.names).index(name)
1111
lengths = [type_num_names(t) for t in base_schema.struct.types]
1212
flat_indices = [0] + list(itertools.accumulate(lengths))[:-1]
13+
field_index = flat_indices.index(column_index)
14+
15+
names_start = flat_indices[field_index]
16+
names_end = (
17+
flat_indices[field_index + 1]
18+
if len(flat_indices) > field_index + 1
19+
else None
20+
)
1321

1422
return stee.ExtendedExpression(
1523
referred_expr=[
@@ -19,11 +27,12 @@ def resolve(base_schema: stp.NamedStruct) -> stee.ExtendedExpression:
1927
root_reference=stalg.Expression.FieldReference.RootReference(),
2028
direct_reference=stalg.Expression.ReferenceSegment(
2129
struct_field=stalg.Expression.ReferenceSegment.StructField(
22-
field=flat_indices.index(column_index)
30+
field=field_index
2331
)
2432
),
2533
)
26-
)
34+
),
35+
output_names=list(base_schema.names)[names_start:names_end],
2736
)
2837
],
2938
base_schema=base_schema,

tests/test_extended_expression.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ def test_column_no_nesting():
5555
)
5656
),
5757
)
58-
)
58+
),
59+
output_names=["description"],
5960
)
6061
],
6162
base_schema=named_struct,
@@ -75,7 +76,29 @@ def test_column_nesting():
7576
)
7677
),
7778
)
78-
)
79+
),
80+
output_names=["order_total"],
81+
)
82+
],
83+
base_schema=nested_named_struct,
84+
)
85+
86+
87+
def test_column_nested_struct():
88+
assert column("shop_details")(nested_named_struct) == stee.ExtendedExpression(
89+
referred_expr=[
90+
stee.ExpressionReference(
91+
expression=stalg.Expression(
92+
selection=stalg.Expression.FieldReference(
93+
root_reference=stalg.Expression.FieldReference.RootReference(),
94+
direct_reference=stalg.Expression.ReferenceSegment(
95+
struct_field=stalg.Expression.ReferenceSegment.StructField(
96+
field=1
97+
)
98+
),
99+
)
100+
),
101+
output_names=["shop_details", "shop_id", "shop_total"],
79102
)
80103
],
81104
base_schema=nested_named_struct,

0 commit comments

Comments
 (0)