Skip to content

Commit ff1c139

Browse files
committed
equivalence classes: fix projection
This patch fixes the logic that projects equivalence classes: when iterating over the projection mapping to find new equivalent expressions, we need to normalize each source expression first.
1 parent a93b4de commit ff1c139

File tree

2 files changed

+138
-2
lines changed

2 files changed

+138
-2
lines changed

datafusion/physical-expr/src/equivalence/class.rs

+62-2
Original file line numberDiff line numberDiff line change
@@ -584,12 +584,18 @@ impl EquivalenceGroup {
584584
.collect::<Vec<_>>();
585585
(new_class.len() > 1).then_some(EquivalenceClass::new(new_class))
586586
});
587+
587588
// The key is the (normalized) source expression, and the value is the EquivalenceClass that collects the target expressions mapped from that source expression.
588589
let mut new_classes: IndexMap<Arc<dyn PhysicalExpr>, EquivalenceClass> =
589590
IndexMap::new();
590591
mapping.iter().for_each(|(source, target)| {
592+
// We need to find equivalent projected expressions.
593+
// e.g. table with columns [a,b,c] and a == b, projection: [a+c, b+c].
594+
// To conclude that a + c == b + c we first normalize all source expressions
595+
// in the mapping, then merge all equivalent expressions into the classes.
596+
let normalized_expr = self.normalize_expr(Arc::clone(source));
591597
new_classes
592-
.entry(Arc::clone(source))
598+
.entry(normalized_expr)
593599
.or_insert_with(EquivalenceClass::new_empty)
594600
.push(Arc::clone(target));
595601
});
@@ -752,8 +758,9 @@ mod tests {
752758

753759
use super::*;
754760
use crate::equivalence::tests::create_test_params;
755-
use crate::expressions::{lit, BinaryExpr, Literal};
761+
use crate::expressions::{binary, col, lit, BinaryExpr, Literal};
756762

763+
use arrow_schema::{DataType, Field, Schema};
757764
use datafusion_common::{Result, ScalarValue};
758765
use datafusion_expr::Operator;
759766

@@ -1038,4 +1045,57 @@ mod tests {
10381045

10391046
Ok(())
10401047
}
1048+
1049+
#[test]
1050+
fn test_project_classes() -> Result<()> {
1051+
// - columns: [a, b, c].
1052+
// - "a" and "b" in the same equivalence class.
1053+
// - then after a+c, b+c projection col(0) and col(1) must be
1054+
// in the same class too.
1055+
let schema = Arc::new(Schema::new(vec![
1056+
Field::new("a", DataType::Int32, false),
1057+
Field::new("b", DataType::Int32, false),
1058+
Field::new("c", DataType::Int32, false),
1059+
]));
1060+
let mut group = EquivalenceGroup::empty();
1061+
group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?);
1062+
1063+
let projected_schema = Arc::new(Schema::new(vec![
1064+
Field::new("a+c", DataType::Int32, false),
1065+
Field::new("b+c", DataType::Int32, false),
1066+
]));
1067+
1068+
let mapping = ProjectionMapping {
1069+
map: vec![
1070+
(
1071+
binary(
1072+
col("a", &schema)?,
1073+
Operator::Plus,
1074+
col("c", &schema)?,
1075+
&schema,
1076+
)?,
1077+
col("a+c", &projected_schema)?,
1078+
),
1079+
(
1080+
binary(
1081+
col("b", &schema)?,
1082+
Operator::Plus,
1083+
col("c", &schema)?,
1084+
&schema,
1085+
)?,
1086+
col("b+c", &projected_schema)?,
1087+
),
1088+
],
1089+
};
1090+
1091+
let projected = group.project(&mapping);
1092+
1093+
assert!(!projected.is_empty());
1094+
let first_normalized = projected.normalize_expr(col("a+c", &projected_schema)?);
1095+
let second_normalized = projected.normalize_expr(col("b+c", &projected_schema)?);
1096+
1097+
assert!(first_normalized.eq(&second_normalized));
1098+
1099+
Ok(())
1100+
}
10411101
}

datafusion/sqllogictest/test_files/join.slt.part

+76
Original file line numberDiff line numberDiff line change
@@ -1312,3 +1312,79 @@ SELECT a+b*2,
13121312

13131313
statement ok
13141314
drop table t1;
1315+
1316+
# Test that equivalence classes are projected correctly.
1317+
1318+
statement ok
1319+
create table pairs(x int, y int) as values (1,1), (2,2), (3,3);
1320+
1321+
statement ok
1322+
create table f(a int) as values (1), (2), (3);
1323+
1324+
statement ok
1325+
create table s(b int) as values (1), (2), (3);
1326+
1327+
statement ok
1328+
set datafusion.optimizer.repartition_joins = true;
1329+
1330+
statement ok
1331+
set datafusion.execution.target_partitions = 16;
1332+
1333+
# After applying the filter (x = y), we can join on both x and y,
1334+
# partitioning only once.
1335+
1336+
query TT
1337+
explain
1338+
SELECT * FROM
1339+
(SELECT x+1 AS col0, y+1 AS col1 FROM PAIRS WHERE x == y)
1340+
JOIN f
1341+
ON col0 = f.a
1342+
JOIN s
1343+
ON col1 = s.b
1344+
----
1345+
logical_plan
1346+
01)Inner Join: col1 = CAST(s.b AS Int64)
1347+
02)--Inner Join: col0 = CAST(f.a AS Int64)
1348+
03)----Projection: CAST(pairs.x AS Int64) + Int64(1) AS col0, CAST(pairs.y AS Int64) + Int64(1) AS col1
1349+
04)------Filter: pairs.y = pairs.x
1350+
05)--------TableScan: pairs projection=[x, y]
1351+
06)----TableScan: f projection=[a]
1352+
07)--TableScan: s projection=[b]
1353+
physical_plan
1354+
01)CoalesceBatchesExec: target_batch_size=8192
1355+
02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col1@1, CAST(s.b AS Int64)@1)], projection=[col0@0, col1@1, a@2, b@3]
1356+
03)----ProjectionExec: expr=[col0@1 as col0, col1@2 as col1, a@0 as a]
1357+
04)------CoalesceBatchesExec: target_batch_size=8192
1358+
05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(CAST(f.a AS Int64)@1, col0@0)], projection=[a@0, col0@2, col1@3]
1359+
06)----------CoalesceBatchesExec: target_batch_size=8192
1360+
07)------------RepartitionExec: partitioning=Hash([CAST(f.a AS Int64)@1], 16), input_partitions=1
1361+
08)--------------ProjectionExec: expr=[a@0 as a, CAST(a@0 AS Int64) as CAST(f.a AS Int64)]
1362+
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
1363+
10)----------CoalesceBatchesExec: target_batch_size=8192
1364+
11)------------RepartitionExec: partitioning=Hash([col0@0], 16), input_partitions=16
1365+
12)--------------ProjectionExec: expr=[CAST(x@0 AS Int64) + 1 as col0, CAST(y@1 AS Int64) + 1 as col1]
1366+
13)----------------RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=1
1367+
14)------------------CoalesceBatchesExec: target_batch_size=8192
1368+
15)--------------------FilterExec: y@1 = x@0
1369+
16)----------------------MemoryExec: partitions=1, partition_sizes=[1]
1370+
17)----CoalesceBatchesExec: target_batch_size=8192
1371+
18)------RepartitionExec: partitioning=Hash([CAST(s.b AS Int64)@1], 16), input_partitions=1
1372+
19)--------ProjectionExec: expr=[b@0 as b, CAST(b@0 AS Int64) as CAST(s.b AS Int64)]
1373+
20)----------MemoryExec: partitions=1, partition_sizes=[1]
1374+
1375+
statement ok
1376+
drop table pairs;
1377+
1378+
statement ok
1379+
drop table f;
1380+
1381+
statement ok
1382+
drop table s;
1383+
1384+
# Reset the configs to old values.
1385+
statement ok
1386+
set datafusion.execution.target_partitions = 4;
1387+
1388+
statement ok
1389+
set datafusion.optimizer.repartition_joins = false;
1390+

0 commit comments

Comments
 (0)