This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new d7e5f97b74 [feature](Nereids): eliminate AssertNumRows (#23842) d7e5f97b74 is described below commit d7e5f97b74a388af0aaab695b07e74901bbb1ae1 Author: jakevin <jakevin...@gmail.com> AuthorDate: Wed Sep 13 22:24:02 2023 +0800 [feature](Nereids): eliminate AssertNumRows (#23842) --- .../doris/nereids/jobs/executor/Rewriter.java | 4 +- .../org/apache/doris/nereids/rules/RuleType.java | 1 + .../rules/rewrite/EliminateAssertNumRows.java | 93 ++++++++++++++++++ .../rules/rewrite/EliminateAssertNumRowsTest.java | 106 +++++++++++++++++++++ .../doris/nereids/util/LogicalPlanBuilder.java | 9 ++ .../data/nereids_tpch_shape_sf500_p0/shape/q11.out | 31 +++--- .../data/nereids_tpch_shape_sf500_p0/shape/q15.out | 21 ++-- .../data/nereids_tpch_shape_sf500_p0/shape/q22.out | 13 ++- 8 files changed, 243 insertions(+), 35 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index de8e92eff3..5e0fd69cbc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -52,6 +52,7 @@ import org.apache.doris.nereids.rules.rewrite.CountLiteralToCountStar; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult; import org.apache.doris.nereids.rules.rewrite.EliminateAggregate; +import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows; import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition; import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation; import org.apache.doris.nereids.rules.rewrite.EliminateFilter; @@ -168,7 +169,8 @@ public class Rewriter extends AbstractBatchJobExecutor { bottomUp( new EliminateLimit(), new EliminateFilter(), - new 
EliminateAggregate() + new EliminateAggregate(), + new EliminateAssertNumRows() ) ), // please note: this rule must run before NormalizeAggregate diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 8dc865482b..27399ec088 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -201,6 +201,7 @@ public enum RuleType { ELIMINATE_OUTER_JOIN(RuleTypeClass.REWRITE), ELIMINATE_DEDUP_JOIN_CONDITION(RuleTypeClass.REWRITE), ELIMINATE_NULL_AWARE_LEFT_ANTI_JOIN(RuleTypeClass.REWRITE), + ELIMINATE_ASSERT_NUM_ROWS(RuleTypeClass.REWRITE), CONVERT_OUTER_JOIN_TO_ANTI(RuleTypeClass.REWRITE), FIND_HASH_CONDITION_FOR_JOIN(RuleTypeClass.REWRITE), MATERIALIZED_INDEX_AGG_SCAN(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java new file mode 100644 index 0000000000..84d459d30d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; +import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; + +/** EliminateAssertNumRows */ +public class EliminateAssertNumRows extends OneRewriteRuleFactory { + + @Override + public Rule build() { + return logicalAssertNumRows() + .then(assertNumRows -> { + Plan checkPlan = assertNumRows.child(); + while (skipPlan(checkPlan) != checkPlan) { + checkPlan = skipPlan(checkPlan); + } + return canEliminate(assertNumRows, checkPlan) ? 
assertNumRows.child() : null; + }).toRule(RuleType.ELIMINATE_ASSERT_NUM_ROWS); + } + + private Plan skipPlan(Plan plan) { + if (plan instanceof LogicalProject || plan instanceof LogicalFilter || plan instanceof LogicalSort) { + plan = plan.child(0); + } else if (plan instanceof LogicalJoin) { + if (((LogicalJoin<?, ?>) plan).getJoinType().isLeftSemiOrAntiJoin()) { + plan = plan.child(0); + } else if (((LogicalJoin<?, ?>) plan).getJoinType().isRightSemiOrAntiJoin()) { + plan = plan.child(1); + } + } + return plan; + } + + private boolean canEliminate(LogicalAssertNumRows<?> assertNumRows, Plan plan) { + long maxOutputRowcount; + // Don't need to consider TopN, because it's generated by Sort + Limit. + if (plan instanceof LogicalLimit) { + maxOutputRowcount = ((LogicalLimit<?>) plan).getLimit(); + } else if (plan instanceof LogicalAggregate && ((LogicalAggregate<?>) plan).getGroupByExpressions().isEmpty()) { + maxOutputRowcount = 1; + } else { + return false; + } + + AssertNumRowsElement assertNumRowsElement = assertNumRows.getAssertNumRowsElement(); + Assertion assertion = assertNumRowsElement.getAssertion(); + long assertNum = assertNumRowsElement.getDesiredNumOfRows(); + + switch (assertion) { + case NE: + case LT: + if (maxOutputRowcount < assertNum) { + return true; + } + break; + case LE: + if (maxOutputRowcount <= assertNum) { + return true; + } + break; + default: + return false; + } + return false; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java new file mode 100644 index 0000000000..f4b647a54e --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +class EliminateAssertNumRowsTest implements MemoPatternMatchSupported { + @Test + void testFailedMatch() { + LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) + .limit(10, 10) + .assertNumRows(Assertion.LT, 10) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new EliminateAssertNumRows()) + .matchesFromRoot( + logicalAssertNumRows(logicalLimit(logicalOlapScan())) + ); + } + + @Test + void testLimit() { + LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) + .limit(10, 
10) + .assertNumRows(Assertion.LE, 10) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new EliminateAssertNumRows()) + .matchesFromRoot( + logicalLimit(logicalOlapScan()) + ); + } + + @Test + void testScalarAgg() { + LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) + .agg(ImmutableList.of(), ImmutableList.of((new Count()).alias("cnt"))) + .assertNumRows(Assertion.LE, 10) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new EliminateAssertNumRows()) + .matchesFromRoot( + logicalAggregate(logicalOlapScan()) + ); + } + + @Test + void skipProject() { + LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) + .limit(10, 10) + .project(ImmutableList.of(0, 1)) + .project(ImmutableList.of(0, 1)) + .assertNumRows(Assertion.LE, 10) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new EliminateAssertNumRows()) + .matchesFromRoot( + logicalProject(logicalProject(logicalLimit(logicalOlapScan()))) + ); + } + + @Test + void skipSemiJoin() { + LogicalPlan plan = new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0)) + .limit(10, 10) + .joinEmptyOn(PlanConstructor.newLogicalOlapScan(1, "t2", 0), JoinType.LEFT_SEMI_JOIN) + .assertNumRows(Assertion.LE, 10) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyTopDown(new EliminateAssertNumRows()) + .matchesFromRoot( + leftSemiLogicalJoin(logicalLimit(logicalOlapScan()), logicalOlapScan()) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index 99f7884b2b..c5024ff931 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -20,6 
+20,8 @@ package org.apache.doris.nereids.util; import org.apache.doris.common.Pair; import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; +import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -28,6 +30,7 @@ import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.LimitPhase; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; @@ -192,4 +195,10 @@ public class LogicalPlanBuilder { LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExprsList, this.plan); return from(agg); } + + public LogicalPlanBuilder assertNumRows(Assertion assertion, long numRows) { + LogicalAssertNumRows<LogicalPlan> assertNumRows = new LogicalAssertNumRows<>( + new AssertNumRowsElement(numRows, "", assertion), this.plan); + return from(assertNumRows); + } } diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out index 8fd0e3341d..1d1e947be2 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out @@ -20,21 +20,20 @@ PhysicalResultSink --------------------------filter((nation.n_name = 'GERMANY')) ----------------------------PhysicalOlapScan[nation] ------------PhysicalDistribute 
---------------PhysicalAssertNumRows -----------------PhysicalProject -------------------hashAgg[GLOBAL] ---------------------PhysicalDistribute -----------------------hashAgg[LOCAL] -------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN](partsupp.ps_suppkey = supplier.s_suppkey) -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[partsupp] -----------------------------PhysicalDistribute -------------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey) +--------------PhysicalProject +----------------hashAgg[GLOBAL] +------------------PhysicalDistribute +--------------------hashAgg[LOCAL] +----------------------PhysicalProject +------------------------hashJoin[INNER_JOIN](partsupp.ps_suppkey = supplier.s_suppkey) +--------------------------PhysicalProject +----------------------------PhysicalOlapScan[partsupp] +--------------------------PhysicalDistribute +----------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey) +------------------------------PhysicalProject +--------------------------------PhysicalOlapScan[supplier] +------------------------------PhysicalDistribute --------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[supplier] ---------------------------------PhysicalDistribute -----------------------------------PhysicalProject -------------------------------------filter((nation.n_name = 'GERMANY')) ---------------------------------------PhysicalOlapScan[nation] +----------------------------------filter((nation.n_name = 'GERMANY')) +------------------------------------PhysicalOlapScan[nation] diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out index 4106594748..ff4350e080 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out +++ 
b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out @@ -17,15 +17,14 @@ PhysicalResultSink ------------------------filter((lineitem.l_shipdate >= 1996-01-01)(lineitem.l_shipdate < 1996-04-01)) --------------------------PhysicalOlapScan[lineitem] ----------------PhysicalDistribute -------------------PhysicalAssertNumRows ---------------------hashAgg[GLOBAL] -----------------------PhysicalDistribute -------------------------hashAgg[LOCAL] ---------------------------PhysicalProject -----------------------------hashAgg[GLOBAL] -------------------------------PhysicalDistribute ---------------------------------hashAgg[LOCAL] -----------------------------------PhysicalProject -------------------------------------filter((lineitem.l_shipdate >= 1996-01-01)(lineitem.l_shipdate < 1996-04-01)) ---------------------------------------PhysicalOlapScan[lineitem] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute +----------------------hashAgg[LOCAL] +------------------------PhysicalProject +--------------------------hashAgg[GLOBAL] +----------------------------PhysicalDistribute +------------------------------hashAgg[LOCAL] +--------------------------------PhysicalProject +----------------------------------filter((lineitem.l_shipdate >= 1996-01-01)(lineitem.l_shipdate < 1996-04-01)) +------------------------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out index c352e2508a..7845eba2ba 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out @@ -18,11 +18,10 @@ PhysicalResultSink ------------------------filter(substring(c_phone, 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')) --------------------------PhysicalOlapScan[customer] ----------------------PhysicalDistribute -------------------------PhysicalAssertNumRows 
---------------------------hashAgg[GLOBAL] -----------------------------PhysicalDistribute -------------------------------hashAgg[LOCAL] ---------------------------------PhysicalProject -----------------------------------filter((customer.c_acctbal > 0.00)substring(c_phone, 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')) -------------------------------------PhysicalOlapScan[customer] +------------------------hashAgg[GLOBAL] +--------------------------PhysicalDistribute +----------------------------hashAgg[LOCAL] +------------------------------PhysicalProject +--------------------------------filter((customer.c_acctbal > 0.00)substring(c_phone, 1, 2) IN ('13', '31', '23', '29', '30', '18', '17')) +----------------------------------PhysicalOlapScan[customer] --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org For additional commands, e-mail: commits-help@doris.apache.org