Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-6893] Remove agg from Union children in IntersectToDistinctRule #4246

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,23 @@
*/
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;

import org.immutables.value.Value;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;

import static org.apache.calcite.util.Util.skipLast;

/**
* Planner rule that translates a distinct
Expand All @@ -40,30 +42,27 @@
* {@link org.apache.calcite.rel.core.Union},
* {@link org.apache.calcite.rel.core.Aggregate}, etc.
*
* <p>Rewrite: (GB-Union All-GB)-GB-UDTF (on all attributes)
*
* <h2>Example</h2>
*
* <p>Query: <code>R1 Intersect All R2</code>
*
* <p><code>R3 = GB(R1 on all attributes, count(*) as c)<br>
* union all<br>
* GB(R2 on all attributes, count(*) as c)</code>
*
* <p><code>R4 = GB(R3 on all attributes, count(c) as cnt, min(c) as m)</code>
*
* <p>Note that we do not need <code>min(c)</code> in intersect distinct.
*
* <p><code>R5 = Filter(cnt == #branch)</code>
*
* <p>If it is intersect all then
* <p>Original query:
* <pre>{@code
* SELECT job FROM "scott".emp WHERE deptno = 10
* INTERSECT
* SELECT job FROM "scott".emp WHERE deptno = 20
* }</pre>
*
* <p><code>R6 = UDTF (R5) which will explode the tuples based on min(c)<br>
* R7 = Project(R6 on all attributes)</code>
*
* <p>Else
*
* <p><code>R6 = Proj(R5 on all attributes)</code>
* <p>Query after conversion:
* <pre>{@code
* SELECT job
* FROM (
* SELECT job, 0 AS i FROM "scott".emp WHERE deptno = 10
* UNION ALL
* SELECT job, 1 AS i FROM "scott".emp WHERE deptno = 20
* )
* GROUP BY job
* HAVING COUNT(*) FILTER (WHERE i = 0) > 0
* AND COUNT(*) FILTER (WHERE i = 1) > 0
* }</pre>
*
* @see org.apache.calcite.rel.rules.UnionToDistinctRule
* @see CoreRules#INTERSECT_TO_DISTINCT
Expand Down Expand Up @@ -93,42 +92,41 @@ public IntersectToDistinctRule(Class<? extends Intersect> intersectClass,
if (intersect.all) {
return; // nothing we can do
}
final RelOptCluster cluster = intersect.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
final RelBuilder relBuilder = call.builder();
final int oriFieldCount = intersect.getRowType().getFieldCount();
final int branchCount = intersect.getInputs().size();

// 1st level GB: create a GB (col0, col1, count() as c) for each branch
for (RelNode input : intersect.getInputs()) {
relBuilder.push(input);
relBuilder.aggregate(relBuilder.groupKey(relBuilder.fields()),
relBuilder.countStar(null));
List<AggCall> aggCalls = new ArrayList<>(branchCount);
for (int i = 0; i < branchCount; ++i) {
relBuilder.push(intersect.getInputs().get(i));
List<RexNode> fields = new ArrayList<>(relBuilder.fields());
fields.add(relBuilder.alias(relBuilder.literal(i), "i"));
relBuilder.project(fields);
aggCalls.add(
relBuilder.countStar(null).filter(
relBuilder.equals(relBuilder.field(oriFieldCount),
relBuilder.literal(i))).as("count_i" + i));
}

// create a union above all the branches
final int branchCount = intersect.getInputs().size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i prefer the old name, branchCount, to inputCount.

relBuilder.union(true, branchCount);
final RelNode union = relBuilder.peek();

// 2nd level GB: create a GB (col0, col1, count(c)) for each branch
// the index of c is union.getRowType().getFieldList().size() - 1
final int fieldCount = union.getRowType().getFieldCount();

final ImmutableBitSet groupSet =
ImmutableBitSet.range(fieldCount - 1);
relBuilder.aggregate(relBuilder.groupKey(groupSet),
relBuilder.countStar(null));

// add a filter count(c) = #branches
relBuilder.filter(
relBuilder.equals(relBuilder.field(fieldCount - 1),
rexBuilder.makeBigintLiteral(new BigDecimal(branchCount))));
// Add aggCalls
relBuilder.aggregate(relBuilder.groupKey(ImmutableBitSet.range(fieldCount)), aggCalls);

// Project all but the last field
relBuilder.project(Util.skipLast(relBuilder.fields()));
// Generate filter count_i{n} > 0 for each branch
List<RexNode> filters = new ArrayList<>(branchCount);
for (int i = 0; i < branchCount; i++) {
filters.add(
relBuilder.greaterThan(relBuilder.field("count_i" + i),
relBuilder.literal(0)));
}

// the schema for intersect distinct is like this
// R3 on all attributes + count(c) as cnt
// finally add a project to project out the last column
relBuilder.filter(filters);
// Project all but the last added field (e.g. i and count_i{n})
relBuilder.project(skipLast(relBuilder.fields(), branchCount + 1));
call.transformTo(relBuilder.build());
}

Expand Down
54 changes: 28 additions & 26 deletions core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6008,21 +6008,22 @@ LogicalIntersect(all=[false])
<Resource name="planAfter">
<![CDATA[
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($9, 3)])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8}], agg#0=[COUNT()])
LogicalUnion(all=[true])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8}], agg#0=[COUNT()])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8}], agg#0=[COUNT()])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 20)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8}], agg#0=[COUNT()])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 30)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalFilter(condition=[AND(>($10, 0), >($11, 0), >($12, 0))])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}], count_i0=[COUNT() FILTER $10], count_i1=[COUNT() FILTER $11], count_i2=[COUNT() FILTER $12])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[$9], $f10=[=($9, 0)], $f11=[=($9, 1)], $f12=[=($9, 2)])
LogicalUnion(all=[true])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[0])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[1])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 20)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[2])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 30)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down Expand Up @@ -6054,17 +6055,18 @@ LogicalIntersect(all=[true])
<![CDATA[
LogicalIntersect(all=[true])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($9, 2)])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8}], agg#0=[COUNT()])
LogicalUnion(all=[true])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8}], agg#0=[COUNT()])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8}], agg#0=[COUNT()])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 20)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalFilter(condition=[AND(>($10, 0), >($11, 0))])
LogicalAggregate(group=[{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}], count_i0=[COUNT() FILTER $10], count_i1=[COUNT() FILTER $11])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[$9], $f10=[=($9, 0)], $f11=[=($9, 1)])
LogicalUnion(all=[true])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[0])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 10)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], i=[1])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 20)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
LogicalFilter(condition=[=($7, 30)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
Expand Down