diff --git a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java
index 35d359bd8b7..a14e25e8482 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java
@@ -16,21 +16,23 @@
*/
package org.apache.calcite.rel.rules;
-import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.logical.LogicalIntersect;
-import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.tools.RelBuilder.AggCall;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
-import org.apache.calcite.util.Util;
import org.immutables.value.Value;
-import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.calcite.util.Util.skipLast;
/**
* Planner rule that translates a distinct
@@ -40,30 +42,27 @@
* {@link org.apache.calcite.rel.core.Union},
* {@link org.apache.calcite.rel.core.Aggregate}, etc.
*
- *
Rewrite: (GB-Union All-GB)-GB-UDTF (on all attributes)
- *
*
Example
*
- * Query: R1 Intersect All R2
- *
- *
R3 = GB(R1 on all attributes, count(*) as c)
- * union all
- * GB(R2 on all attributes, count(*) as c)
- *
- *
R4 = GB(R3 on all attributes, count(c) as cnt, min(c) as m)
- *
- *
Note that we do not need min(c)
in intersect distinct.
- *
- *
R5 = Filter(cnt == #branch)
- *
- *
If it is intersect all then
+ *
Original query:
+ *
{@code
+ * SELECT job FROM "scott".emp WHERE deptno = 10
+ * INTERSECT
+ * SELECT job FROM "scott".emp WHERE deptno = 20
+ * }
*
- * R6 = UDTF (R5) which will explode the tuples based on min(c)
- * R7 = Project(R6 on all attributes)
- *
- *
Else
- *
- *
R6 = Proj(R5 on all attributes)
+ *
Query after conversion:
+ *
{@code
+ * SELECT job
+ * FROM (
+ * SELECT job, 0 AS i FROM "scott".emp WHERE deptno = 10
+ * UNION ALL
+ * SELECT job, 1 AS i FROM "scott".emp WHERE deptno = 20
+ * )
+ * GROUP BY job
+ * HAVING COUNT(*) FILTER (WHERE i = 0) > 0
+ * AND COUNT(*) FILTER (WHERE i = 1) > 0
+ * }
*
* @see org.apache.calcite.rel.rules.UnionToDistinctRule
* @see CoreRules#INTERSECT_TO_DISTINCT
@@ -93,42 +92,41 @@ public IntersectToDistinctRule(Class extends Intersect> intersectClass,
if (intersect.all) {
return; // nothing we can do
}
- final RelOptCluster cluster = intersect.getCluster();
- final RexBuilder rexBuilder = cluster.getRexBuilder();
final RelBuilder relBuilder = call.builder();
+ final int oriFieldCount = intersect.getRowType().getFieldCount();
+ final int branchCount = intersect.getInputs().size();
- // 1st level GB: create a GB (col0, col1, count() as c) for each branch
- for (RelNode input : intersect.getInputs()) {
- relBuilder.push(input);
- relBuilder.aggregate(relBuilder.groupKey(relBuilder.fields()),
- relBuilder.countStar(null));
+ List aggCalls = new ArrayList<>(branchCount);
+ for (int i = 0; i < branchCount; ++i) {
+ relBuilder.push(intersect.getInputs().get(i));
+ List fields = new ArrayList<>(relBuilder.fields());
+ fields.add(relBuilder.alias(relBuilder.literal(i), "i"));
+ relBuilder.project(fields);
+ aggCalls.add(
+ relBuilder.countStar(null).filter(
+ relBuilder.equals(relBuilder.field(oriFieldCount),
+ relBuilder.literal(i))).as("count_i" + i));
}
// create a union above all the branches
- final int branchCount = intersect.getInputs().size();
relBuilder.union(true, branchCount);
final RelNode union = relBuilder.peek();
-
- // 2nd level GB: create a GB (col0, col1, count(c)) for each branch
- // the index of c is union.getRowType().getFieldList().size() - 1
final int fieldCount = union.getRowType().getFieldCount();
- final ImmutableBitSet groupSet =
- ImmutableBitSet.range(fieldCount - 1);
- relBuilder.aggregate(relBuilder.groupKey(groupSet),
- relBuilder.countStar(null));
-
- // add a filter count(c) = #branches
- relBuilder.filter(
- relBuilder.equals(relBuilder.field(fieldCount - 1),
- rexBuilder.makeBigintLiteral(new BigDecimal(branchCount))));
+ // Add aggCalls
+ relBuilder.aggregate(relBuilder.groupKey(ImmutableBitSet.range(fieldCount)), aggCalls);
- // Project all but the last field
- relBuilder.project(Util.skipLast(relBuilder.fields()));
+ // Generate filter count_i{n} > 0 for each branch
+ List filters = new ArrayList<>(branchCount);
+ for (int i = 0; i < branchCount; i++) {
+ filters.add(
+ relBuilder.greaterThan(relBuilder.field("count_i" + i),
+ relBuilder.literal(0)));
+ }
- // the schema for intersect distinct is like this
- // R3 on all attributes + count(c) as cnt
- // finally add a project to project out the last column
+ relBuilder.filter(filters);
+ // Project all but the last added field (e.g. i and count_i{n})
+ relBuilder.project(skipLast(relBuilder.fields(), branchCount + 1));
call.transformTo(relBuilder.build());
}
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index be8a8aeaadd..88a7c60bd63 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -6008,21 +6008,22 @@ LogicalIntersect(all=[false])
($10, 0), >($11, 0), >($12, 0))])
+ LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}], count_i0=[COUNT() FILTER $10], count_i1=[COUNT() FILTER $11], count_i2=[COUNT() FILTER $12])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[$9], $f10=[=($9, 0)], $f11=[=($9, 1)], $f12=[=($9, 2)])
+ LogicalUnion(all=[true])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[0])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+ LogicalFilter(condition=[=($7, 10)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[1])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+ LogicalFilter(condition=[=($7, 20)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[2])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+ LogicalFilter(condition=[=($7, 30)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
@@ -6054,17 +6055,18 @@ LogicalIntersect(all=[true])
($10, 0), >($11, 0))])
+ LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}], count_i0=[COUNT() FILTER $10], count_i1=[COUNT() FILTER $11])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[$9], $f10=[=($9, 0)], $f11=[=($9, 1)])
+ LogicalUnion(all=[true])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[0])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+ LogicalFilter(condition=[=($7, 10)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[1])
+ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+ LogicalFilter(condition=[=($7, 20)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 30)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])