This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new d7e5f97b74 [feature](Nereids): eliminate AssertNumRows (#23842)
d7e5f97b74 is described below

commit d7e5f97b74a388af0aaab695b07e74901bbb1ae1
Author: jakevin <jakevin...@gmail.com>
AuthorDate: Wed Sep 13 22:24:02 2023 +0800

    [feature](Nereids): eliminate AssertNumRows (#23842)
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   4 +-
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../rules/rewrite/EliminateAssertNumRows.java      |  93 ++++++++++++++++++
 .../rules/rewrite/EliminateAssertNumRowsTest.java  | 106 +++++++++++++++++++++
 .../doris/nereids/util/LogicalPlanBuilder.java     |   9 ++
 .../data/nereids_tpch_shape_sf500_p0/shape/q11.out |  31 +++---
 .../data/nereids_tpch_shape_sf500_p0/shape/q15.out |  21 ++--
 .../data/nereids_tpch_shape_sf500_p0/shape/q22.out |  13 ++-
 8 files changed, 243 insertions(+), 35 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index de8e92eff3..5e0fd69cbc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -52,6 +52,7 @@ import org.apache.doris.nereids.rules.rewrite.CountLiteralToCountStar;
 import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
 import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
 import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
+import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows;
 import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition;
 import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation;
 import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
@@ -168,7 +169,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
                     bottomUp(
                             new EliminateLimit(),
                             new EliminateFilter(),
-                            new EliminateAggregate()
+                            new EliminateAggregate(),
+                            new EliminateAssertNumRows()
                     )
             ),
             // please note: this rule must run before NormalizeAggregate
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 8dc865482b..27399ec088 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -201,6 +201,7 @@ public enum RuleType {
     ELIMINATE_OUTER_JOIN(RuleTypeClass.REWRITE),
     ELIMINATE_DEDUP_JOIN_CONDITION(RuleTypeClass.REWRITE),
     ELIMINATE_NULL_AWARE_LEFT_ANTI_JOIN(RuleTypeClass.REWRITE),
+    ELIMINATE_ASSERT_NUM_ROWS(RuleTypeClass.REWRITE),
     CONVERT_OUTER_JOIN_TO_ANTI(RuleTypeClass.REWRITE),
     FIND_HASH_CONDITION_FOR_JOIN(RuleTypeClass.REWRITE),
     MATERIALIZED_INDEX_AGG_SCAN(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java
new file mode 100644
index 0000000000..84d459d30d
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRows.java
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
+import 
org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
+
+/**
+ * Removes a LogicalAssertNumRows node when the plan below it already bounds the
+ * number of output rows tightly enough that the assertion can never fail.
+ */
+public class EliminateAssertNumRows extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule build() {
+        return logicalAssertNumRows()
+                .then(assertNumRows -> {
+                    // Walk down through every skippable operator until skipPlan stops descending.
+                    Plan bound = assertNumRows.child();
+                    Plan next = skipPlan(bound);
+                    while (next != bound) {
+                        bound = next;
+                        next = skipPlan(bound);
+                    }
+                    if (!canEliminate(assertNumRows, bound)) {
+                        // null: nothing to rewrite for this node.
+                        return null;
+                    }
+                    // Drop only the assert node; everything below it is kept as is.
+                    return assertNumRows.child();
+                }).toRule(RuleType.ELIMINATE_ASSERT_NUM_ROWS);
+    }
+
+    /**
+     * Steps one level down through an operator that is skipped when looking for the
+     * row-count bound: project, filter and sort descend to their only child, a left
+     * semi/anti join to its left child and a right semi/anti join to its right child.
+     *
+     * @param plan the plan node to inspect
+     * @return the child to continue with, or {@code plan} itself if it is not skippable
+     */
+    private Plan skipPlan(Plan plan) {
+        if (plan instanceof LogicalProject || plan instanceof LogicalFilter || plan instanceof LogicalSort) {
+            return plan.child(0);
+        }
+        if (!(plan instanceof LogicalJoin)) {
+            return plan;
+        }
+        LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
+        if (join.getJoinType().isLeftSemiOrAntiJoin()) {
+            return plan.child(0);
+        }
+        if (join.getJoinType().isRightSemiOrAntiJoin()) {
+            return plan.child(1);
+        }
+        // Any other join type is not skipped.
+        return plan;
+    }
+
+    /**
+     * Decides whether the assertion is guaranteed to hold given the row-count bound
+     * derived from {@code plan}.
+     *
+     * @param assertNumRows the assert node being considered for removal
+     * @param plan the first non-skippable plan found below the assert node
+     * @return true only for NE/LT/LE assertions that the bound satisfies
+     */
+    private boolean canEliminate(LogicalAssertNumRows<?> assertNumRows, Plan plan) {
+        final long maxRows;
+        // Don't need to consider TopN, because it's generated by Sort + Limit.
+        if (plan instanceof LogicalLimit) {
+            maxRows = ((LogicalLimit<?>) plan).getLimit();
+        } else if (plan instanceof LogicalAggregate
+                && ((LogicalAggregate<?>) plan).getGroupByExpressions().isEmpty()) {
+            // An aggregate without group-by keys is treated as producing at most one row.
+            maxRows = 1;
+        } else {
+            return false;
+        }
+
+        AssertNumRowsElement element = assertNumRows.getAssertNumRowsElement();
+        Assertion assertion = element.getAssertion();
+        long expected = element.getDesiredNumOfRows();
+
+        switch (assertion) {
+            case NE:
+            case LT:
+                // Strictly fewer rows than expected: can neither equal nor reach it.
+                return maxRows < expected;
+            case LE:
+                return maxRows <= expected;
+            default:
+                // Every other assertion kind is left in place.
+                return false;
+        }
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java
new file mode 100644
index 0000000000..f4b647a54e
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateAssertNumRowsTest.java
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import 
org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+/** Tests for the EliminateAssertNumRows rewrite rule. */
+class EliminateAssertNumRowsTest implements MemoPatternMatchSupported {
+
+    /** Starts a new plan builder on a fresh olap scan of "t1". */
+    private static LogicalPlanBuilder scanT1() {
+        return new LogicalPlanBuilder(PlanConstructor.newLogicalOlapScan(0, "t1", 0));
+    }
+
+    /** LT 10 is not implied by limit(10, 10), so the assert node must stay. */
+    @Test
+    void testFailedMatch() {
+        LogicalPlan plan = scanT1()
+                .limit(10, 10)
+                .assertNumRows(Assertion.LT, 10)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new EliminateAssertNumRows())
+                .matchesFromRoot(
+                        logicalAssertNumRows(logicalLimit(logicalOlapScan()))
+                );
+    }
+
+    /** LE 10 directly above limit(10, 10): the assert node is removed. */
+    @Test
+    void testLimit() {
+        LogicalPlan plan = scanT1()
+                .limit(10, 10)
+                .assertNumRows(Assertion.LE, 10)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new EliminateAssertNumRows())
+                .matchesFromRoot(
+                        logicalLimit(logicalOlapScan())
+                );
+    }
+
+    /** LE 10 above an aggregate without group-by keys: the assert node is removed. */
+    @Test
+    void testScalarAgg() {
+        LogicalPlan plan = scanT1()
+                .agg(ImmutableList.of(), ImmutableList.of((new Count()).alias("cnt")))
+                .assertNumRows(Assertion.LE, 10)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new EliminateAssertNumRows())
+                .matchesFromRoot(
+                        logicalAggregate(logicalOlapScan())
+                );
+    }
+
+    /** Projects between the assert node and the limit are looked through and kept. */
+    @Test
+    void skipProject() {
+        LogicalPlan plan = scanT1()
+                .limit(10, 10)
+                .project(ImmutableList.of(0, 1))
+                .project(ImmutableList.of(0, 1))
+                .assertNumRows(Assertion.LE, 10)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new EliminateAssertNumRows())
+                .matchesFromRoot(
+                        logicalProject(logicalProject(logicalLimit(logicalOlapScan())))
+                );
+    }
+
+    /** A left semi join is looked through on its left side; the join itself is kept. */
+    @Test
+    void skipSemiJoin() {
+        LogicalPlan plan = scanT1()
+                .limit(10, 10)
+                .joinEmptyOn(PlanConstructor.newLogicalOlapScan(1, "t2", 0), JoinType.LEFT_SEMI_JOIN)
+                .assertNumRows(Assertion.LE, 10)
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyTopDown(new EliminateAssertNumRows())
+                .matchesFromRoot(
+                        leftSemiLogicalJoin(logicalLimit(logicalOlapScan()), logicalOlapScan())
+                );
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
index 99f7884b2b..c5024ff931 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java
@@ -20,6 +20,8 @@ package org.apache.doris.nereids.util;
 import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.properties.OrderKey;
 import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
+import 
org.apache.doris.nereids.trees.expressions.AssertNumRowsElement.Assertion;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -28,6 +30,7 @@ import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.LimitPhase;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
@@ -192,4 +195,10 @@ public class LogicalPlanBuilder {
         LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, 
outputExprsList, this.plan);
         return from(agg);
     }
+
+    /**
+     * Wraps the current plan in a LogicalAssertNumRows node.
+     *
+     * @param assertion the comparison the assert node applies
+     * @param numRows the row count the assertion compares against
+     * @return a builder positioned on the new assert node
+     */
+    public LogicalPlanBuilder assertNumRows(Assertion assertion, long numRows) {
+        // The second constructor argument is passed as an empty string here.
+        AssertNumRowsElement element = new AssertNumRowsElement(numRows, "", assertion);
+        LogicalAssertNumRows<LogicalPlan> assertNode = new LogicalAssertNumRows<>(element, this.plan);
+        return from(assertNode);
+    }
 }
diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out
index 8fd0e3341d..1d1e947be2 100644
--- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out
+++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q11.out
@@ -20,21 +20,20 @@ PhysicalResultSink
 --------------------------filter((nation.n_name = 'GERMANY'))
 ----------------------------PhysicalOlapScan[nation]
 ------------PhysicalDistribute
---------------PhysicalAssertNumRows
-----------------PhysicalProject
-------------------hashAgg[GLOBAL]
---------------------PhysicalDistribute
-----------------------hashAgg[LOCAL]
-------------------------PhysicalProject
---------------------------hashJoin[INNER_JOIN](partsupp.ps_suppkey = supplier.s_suppkey)
-----------------------------PhysicalProject
-------------------------------PhysicalOlapScan[partsupp]
-----------------------------PhysicalDistribute
-------------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey)
+--------------PhysicalProject
+----------------hashAgg[GLOBAL]
+------------------PhysicalDistribute
+--------------------hashAgg[LOCAL]
+----------------------PhysicalProject
+------------------------hashJoin[INNER_JOIN](partsupp.ps_suppkey = supplier.s_suppkey)
+--------------------------PhysicalProject
+----------------------------PhysicalOlapScan[partsupp]
+--------------------------PhysicalDistribute
+----------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey)
+------------------------------PhysicalProject
+--------------------------------PhysicalOlapScan[supplier]
+------------------------------PhysicalDistribute
 --------------------------------PhysicalProject
-----------------------------------PhysicalOlapScan[supplier]
---------------------------------PhysicalDistribute
-----------------------------------PhysicalProject
-------------------------------------filter((nation.n_name = 'GERMANY'))
---------------------------------------PhysicalOlapScan[nation]
+----------------------------------filter((nation.n_name = 'GERMANY'))
+------------------------------------PhysicalOlapScan[nation]
 
diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out
index 4106594748..ff4350e080 100644
--- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out
+++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q15.out
@@ -17,15 +17,14 @@ PhysicalResultSink
 ------------------------filter((lineitem.l_shipdate >= 1996-01-01)(lineitem.l_shipdate < 1996-04-01))
 --------------------------PhysicalOlapScan[lineitem]
 ----------------PhysicalDistribute
-------------------PhysicalAssertNumRows
---------------------hashAgg[GLOBAL]
-----------------------PhysicalDistribute
-------------------------hashAgg[LOCAL]
---------------------------PhysicalProject
-----------------------------hashAgg[GLOBAL]
-------------------------------PhysicalDistribute
---------------------------------hashAgg[LOCAL]
-----------------------------------PhysicalProject
-------------------------------------filter((lineitem.l_shipdate >= 1996-01-01)(lineitem.l_shipdate < 1996-04-01))
---------------------------------------PhysicalOlapScan[lineitem]
+------------------hashAgg[GLOBAL]
+--------------------PhysicalDistribute
+----------------------hashAgg[LOCAL]
+------------------------PhysicalProject
+--------------------------hashAgg[GLOBAL]
+----------------------------PhysicalDistribute
+------------------------------hashAgg[LOCAL]
+--------------------------------PhysicalProject
+----------------------------------filter((lineitem.l_shipdate >= 1996-01-01)(lineitem.l_shipdate < 1996-04-01))
+------------------------------------PhysicalOlapScan[lineitem]
 
diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out
index c352e2508a..7845eba2ba 100644
--- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out
+++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q22.out
@@ -18,11 +18,10 @@ PhysicalResultSink
 ------------------------filter(substring(c_phone, 1, 2) IN ('13', '31', '23', '29', '30', '18', '17'))
 --------------------------PhysicalOlapScan[customer]
 ----------------------PhysicalDistribute
-------------------------PhysicalAssertNumRows
---------------------------hashAgg[GLOBAL]
-----------------------------PhysicalDistribute
-------------------------------hashAgg[LOCAL]
---------------------------------PhysicalProject
-----------------------------------filter((customer.c_acctbal > 0.00)substring(c_phone, 1, 2) IN ('13', '31', '23', '29', '30', '18', '17'))
-------------------------------------PhysicalOlapScan[customer]
+------------------------hashAgg[GLOBAL]
+--------------------------PhysicalDistribute
+----------------------------hashAgg[LOCAL]
+------------------------------PhysicalProject
+--------------------------------filter((customer.c_acctbal > 0.00)substring(c_phone, 1, 2) IN ('13', '31', '23', '29', '30', '18', '17'))
+----------------------------------PhysicalOlapScan[customer]
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to