@@ -584,12 +584,18 @@ impl EquivalenceGroup {
584
584
. collect :: < Vec < _ > > ( ) ;
585
585
( new_class. len ( ) > 1 ) . then_some ( EquivalenceClass :: new ( new_class) )
586
586
} ) ;
587
+
587
588
// the key is the source expression and the value is the EquivalenceClass that contains the target expression of the source expression.
588
589
let mut new_classes: IndexMap < Arc < dyn PhysicalExpr > , EquivalenceClass > =
589
590
IndexMap :: new ( ) ;
590
591
mapping. iter ( ) . for_each ( |( source, target) | {
592
+ // We need to find equivalent projected expressions.
593
+ // e.g. table with columns [a,b,c] and a == b, projection: [a+c, b+c].
594
+ // To conclude that a + c == b + c we firsty normalize all source expressions
595
+ // in the mapping, then merge all equivalent expressions into the classes.
596
+ let normalized_expr = self . normalize_expr ( Arc :: clone ( source) ) ;
591
597
new_classes
592
- . entry ( Arc :: clone ( source ) )
598
+ . entry ( normalized_expr )
593
599
. or_insert_with ( EquivalenceClass :: new_empty)
594
600
. push ( Arc :: clone ( target) ) ;
595
601
} ) ;
@@ -752,8 +758,9 @@ mod tests {
752
758
753
759
use super :: * ;
754
760
use crate :: equivalence:: tests:: create_test_params;
755
- use crate :: expressions:: { lit, BinaryExpr , Literal } ;
761
+ use crate :: expressions:: { binary , col , lit, BinaryExpr , Literal } ;
756
762
763
+ use arrow_schema:: { DataType , Field , Schema } ;
757
764
use datafusion_common:: { Result , ScalarValue } ;
758
765
use datafusion_expr:: Operator ;
759
766
@@ -1038,4 +1045,57 @@ mod tests {
1038
1045
1039
1046
Ok ( ( ) )
1040
1047
}
1048
+
1049
+ #[ test]
1050
+ fn test_project_classes ( ) -> Result < ( ) > {
1051
+ // - columns: [a, b, c].
1052
+ // - "a" and "b" in the same equivalence class.
1053
+ // - then after a+c, b+c projection col(0) and col(1) must be
1054
+ // in the same class too.
1055
+ let schema = Arc :: new ( Schema :: new ( vec ! [
1056
+ Field :: new( "a" , DataType :: Int32 , false ) ,
1057
+ Field :: new( "b" , DataType :: Int32 , false ) ,
1058
+ Field :: new( "c" , DataType :: Int32 , false ) ,
1059
+ ] ) ) ;
1060
+ let mut group = EquivalenceGroup :: empty ( ) ;
1061
+ group. add_equal_conditions ( & col ( "a" , & schema) ?, & col ( "b" , & schema) ?) ;
1062
+
1063
+ let projected_schema = Arc :: new ( Schema :: new ( vec ! [
1064
+ Field :: new( "a+c" , DataType :: Int32 , false ) ,
1065
+ Field :: new( "b+c" , DataType :: Int32 , false ) ,
1066
+ ] ) ) ;
1067
+
1068
+ let mapping = ProjectionMapping {
1069
+ map : vec ! [
1070
+ (
1071
+ binary(
1072
+ col( "a" , & schema) ?,
1073
+ Operator :: Plus ,
1074
+ col( "c" , & schema) ?,
1075
+ & schema,
1076
+ ) ?,
1077
+ col( "a+c" , & projected_schema) ?,
1078
+ ) ,
1079
+ (
1080
+ binary(
1081
+ col( "b" , & schema) ?,
1082
+ Operator :: Plus ,
1083
+ col( "c" , & schema) ?,
1084
+ & schema,
1085
+ ) ?,
1086
+ col( "b+c" , & projected_schema) ?,
1087
+ ) ,
1088
+ ] ,
1089
+ } ;
1090
+
1091
+ let projected = group. project ( & mapping) ;
1092
+
1093
+ assert ! ( !projected. is_empty( ) ) ;
1094
+ let first_normalized = projected. normalize_expr ( col ( "a+c" , & projected_schema) ?) ;
1095
+ let second_normalized = projected. normalize_expr ( col ( "b+c" , & projected_schema) ?) ;
1096
+
1097
+ assert ! ( first_normalized. eq( & second_normalized) ) ;
1098
+
1099
+ Ok ( ( ) )
1100
+ }
1041
1101
}
0 commit comments