diff --git a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java index 35d359bd8b7..a14e25e8482 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToDistinctRule.java @@ -16,21 +16,23 @@ */ package org.apache.calcite.rel.rules; -import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Intersect; import org.apache.calcite.rel.logical.LogicalIntersect; -import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilder.AggCall; import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.ImmutableBitSet; -import org.apache.calcite.util.Util; import org.immutables.value.Value; -import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.calcite.util.Util.skipLast; /** * Planner rule that translates a distinct @@ -40,30 +42,27 @@ * {@link org.apache.calcite.rel.core.Union}, * {@link org.apache.calcite.rel.core.Aggregate}, etc. * - *

Rewrite: (GB-Union All-GB)-GB-UDTF (on all attributes) - * *

Example

* - *

Query: R1 Intersect All R2 - * - *

R3 = GB(R1 on all attributes, count(*) as c)
- * union all
- * GB(R2 on all attributes, count(*) as c)
- * - *

R4 = GB(R3 on all attributes, count(c) as cnt, min(c) as m) - * - *

Note that we do not need min(c) in intersect distinct. - * - *

R5 = Filter(cnt == #branch) - * - *

If it is intersect all then + *

Original query: + *

{@code
+ * SELECT job FROM "scott".emp WHERE deptno = 10
+ * INTERSECT
+ * SELECT job FROM "scott".emp WHERE deptno = 20
+ * }
* - *

R6 = UDTF (R5) which will explode the tuples based on min(c)
- * R7 = Project(R6 on all attributes)
- * - *

Else - * - *

R6 = Proj(R5 on all attributes) + *

Query after conversion: + *

{@code
+ * SELECT job
+ * FROM (
+ *   SELECT job, 0 AS i FROM "scott".emp WHERE deptno = 10
+ *   UNION ALL
+ *   SELECT job, 1 AS i FROM "scott".emp WHERE deptno = 20
+ * )
+ * GROUP BY job
+ * HAVING COUNT(*) FILTER (WHERE i = 0) > 0
+ *    AND COUNT(*) FILTER (WHERE i = 1) > 0
+ * }
* * @see org.apache.calcite.rel.rules.UnionToDistinctRule * @see CoreRules#INTERSECT_TO_DISTINCT @@ -93,42 +92,41 @@ public IntersectToDistinctRule(Class intersectClass, if (intersect.all) { return; // nothing we can do } - final RelOptCluster cluster = intersect.getCluster(); - final RexBuilder rexBuilder = cluster.getRexBuilder(); final RelBuilder relBuilder = call.builder(); + final int oriFieldCount = intersect.getRowType().getFieldCount(); + final int branchCount = intersect.getInputs().size(); - // 1st level GB: create a GB (col0, col1, count() as c) for each branch - for (RelNode input : intersect.getInputs()) { - relBuilder.push(input); - relBuilder.aggregate(relBuilder.groupKey(relBuilder.fields()), - relBuilder.countStar(null)); + List aggCalls = new ArrayList<>(branchCount); + for (int i = 0; i < branchCount; ++i) { + relBuilder.push(intersect.getInputs().get(i)); + List fields = new ArrayList<>(relBuilder.fields()); + fields.add(relBuilder.alias(relBuilder.literal(i), "i")); + relBuilder.project(fields); + aggCalls.add( + relBuilder.countStar(null).filter( + relBuilder.equals(relBuilder.field(oriFieldCount), + relBuilder.literal(i))).as("count_i" + i)); } // create a union above all the branches - final int branchCount = intersect.getInputs().size(); relBuilder.union(true, branchCount); final RelNode union = relBuilder.peek(); - - // 2nd level GB: create a GB (col0, col1, count(c)) for each branch - // the index of c is union.getRowType().getFieldList().size() - 1 final int fieldCount = union.getRowType().getFieldCount(); - final ImmutableBitSet groupSet = - ImmutableBitSet.range(fieldCount - 1); - relBuilder.aggregate(relBuilder.groupKey(groupSet), - relBuilder.countStar(null)); - - // add a filter count(c) = #branches - relBuilder.filter( - relBuilder.equals(relBuilder.field(fieldCount - 1), - rexBuilder.makeBigintLiteral(new BigDecimal(branchCount)))); + // Add aggCalls + relBuilder.aggregate(relBuilder.groupKey(ImmutableBitSet.range(fieldCount)), aggCalls); - // Project all but the last field - relBuilder.project(Util.skipLast(relBuilder.fields())); + // Generate filter count_i{n} > 0 for each branch + List filters = new ArrayList<>(branchCount); + for (int i = 0; i < branchCount; i++) { + filters.add( + relBuilder.greaterThan(relBuilder.field("count_i" + i), + relBuilder.literal(0))); + } - // the schema for intersect distinct is like this - // R3 on all attributes + count(c) as cnt - // finally add a project to project out the last column + relBuilder.filter(filters); + // Project all but the last added field (e.g. i and count_i{n}) + relBuilder.project(skipLast(relBuilder.fields(), branchCount + 1)); call.transformTo(relBuilder.build()); } diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index be8a8aeaadd..88a7c60bd63 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -6008,21 +6008,22 @@ LogicalIntersect(all=[false]) ($10, 0), >($11, 0), >($12, 0))]) + LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}], count_i0=[COUNT() FILTER $10], count_i1=[COUNT() FILTER $11], count_i2=[COUNT() FILTER $12]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[$9], $f10=[=($9, 0)], $f11=[=($9, 1)], $f12=[=($9, 2)]) + LogicalUnion(all=[true]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[0]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) + LogicalFilter(condition=[=($7, 10)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[1]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) + LogicalFilter(condition=[=($7, 20)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[2]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) + LogicalFilter(condition=[=($7, 30)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> @@ -6054,17 +6055,18 @@ LogicalIntersect(all=[true]) ($10, 0), >($11, 0))]) + LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}], count_i0=[COUNT() FILTER $10], count_i1=[COUNT() FILTER $11]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[$9], $f10=[=($9, 0)], $f11=[=($9, 1)]) + LogicalUnion(all=[true]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[0]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) + LogicalFilter(condition=[=($7, 10)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[1]) + LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) + LogicalFilter(condition=[=($7, 20)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8]) LogicalFilter(condition=[=($7, 30)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]])