This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git

commit e2a54f6408ae07a6b85da3206479a8602c00e7a2
Author: Henry2SS <45096548+henry...@users.noreply.github.com>
AuthorDate: Thu Dec 29 14:50:32 2022 +0800

    [enhancement](session var) variable to control whether to rewrite OR to IN or not (#15437)
---
 .../java/org/apache/doris/qe/SessionVariable.java  | 14 ++++++
 .../doris/rewrite/ExtractCommonFactorsRule.java    | 51 +++++++++++++++++-----
 .../org/apache/doris/planner/QueryPlanTest.java    | 17 ++++++++
 .../data/performance_p0/redundant_conjuncts.out    |  2 +-
 .../performance_p0/redundant_conjuncts.groovy      |  1 +
 5 files changed, 74 insertions(+), 11 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index f0a556067f..8d8f412aee 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -196,6 +196,9 @@ public class SessionVariable implements Serializable, 
Writable {
 
     //percentage of EXEC_MEM_LIMIT
     public static final String BROADCAST_HASHTABLE_MEM_LIMIT_PERCENTAGE = 
"broadcast_hashtable_mem_limit_percentage";
+
+    public static final String REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 
"rewrite_or_to_in_predicate_threshold";
+
     public static final String NEREIDS_STAR_SCHEMA_SUPPORT = 
"nereids_star_schema_support";
 
     public static final String NEREIDS_CBO_PENALTY_FACTOR = 
"nereids_cbo_penalty_factor";
@@ -544,6 +547,9 @@ public class SessionVariable implements Serializable, 
Writable {
     @VariableMgr.VarAttr(name = NEREIDS_STAR_SCHEMA_SUPPORT)
     private boolean nereidsStarSchemaSupport = true;
 
+    @VariableMgr.VarAttr(name = REWRITE_OR_TO_IN_PREDICATE_THRESHOLD)
+    private int rewriteOrToInPredicateThreshold = 2;
+
     @VariableMgr.VarAttr(name = NEREIDS_CBO_PENALTY_FACTOR)
     private double nereidsCboPenaltyFactor = 0.7;
     @VariableMgr.VarAttr(name = ENABLE_NEREIDS_TRACE)
@@ -661,6 +667,14 @@ public class SessionVariable implements Serializable, 
Writable {
         this.blockEncryptionMode = blockEncryptionMode;
     }
 
+    public void setRewriteOrToInPredicateThreshold(int threshold) {
+        this.rewriteOrToInPredicateThreshold = threshold;
+    }
+
+    public int getRewriteOrToInPredicateThreshold() {
+        return rewriteOrToInPredicateThreshold;
+    }
+
     public long getMaxExecMemByte() {
         return maxExecMemByte;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
 
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
index 5a3bc34c8c..3c808c6100 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java
@@ -28,6 +28,7 @@ import org.apache.doris.analysis.SlotRef;
 import org.apache.doris.analysis.TableName;
 import org.apache.doris.common.AnalysisException;
 import org.apache.doris.planner.PlanNode;
+import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.rewrite.ExprRewriter.ClauseType;
 
 import com.google.common.base.Preconditions;
@@ -43,6 +44,7 @@ import org.apache.logging.log4j.Logger;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
@@ -462,6 +464,13 @@ public class ExtractCommonFactorsRule implements 
ExprRewriteRule {
         boolean isOrToInAllowed = true;
         Set<String> slotSet = new LinkedHashSet<>();
 
+        int rewriteThreshold;
+        if (ConnectContext.get() == null) {
+            rewriteThreshold = 2;
+        } else {
+            rewriteThreshold = 
ConnectContext.get().getSessionVariable().getRewriteOrToInPredicateThreshold();
+        }
+
         for (int i = 0; i < exprs.size(); i++) {
             Expr predicate = exprs.get(i);
             if (!(predicate instanceof BinaryPredicate) && !(predicate 
instanceof InPredicate)) {
@@ -492,22 +501,44 @@ public class ExtractCommonFactorsRule implements 
ExprRewriteRule {
         // isOrToInAllowed : true, means can rewrite
         // slotSet.size : nums of columnName in exprs, should be 1
         if (isOrToInAllowed && slotSet.size() == 1) {
-            // slotRef to get ColumnName
-
-            // SlotRef firstSlot = (SlotRef) exprs.get(0).getChild(0);
-            List<Expr> childrenList = exprs.get(0).getChildren();
-            inPredicate = new InPredicate(exprs.get(0).getChild(0),
-                    childrenList.subList(1, childrenList.size()), false);
-
-            for (int i = 1; i < exprs.size(); i++) {
-                childrenList = exprs.get(i).getChildren();
-                inPredicate.addChildren(childrenList.subList(1, 
childrenList.size()));
+            if (exprs.size() < rewriteThreshold) {
+                return null;
             }
+
+            // get deduplication list
+            List<Expr> deduplicationExprs = getDeduplicationList(exprs);
+            inPredicate = new InPredicate(deduplicationExprs.get(0),
+                    deduplicationExprs.subList(1, deduplicationExprs.size()), 
false);
         }
 
         return inPredicate;
     }
 
+    public List<Expr> getDeduplicationList(List<Expr> exprs) {
+        Set<Expr> set = new HashSet<>();
+        List<Expr> deduplicationExprList = new ArrayList<>();
+
+        deduplicationExprList.add(exprs.get(0).getChild(0));
+
+        for (Expr expr : exprs) {
+            if (expr instanceof BinaryPredicate) {
+                if (!set.contains(expr.getChild(1))) {
+                    set.add(expr.getChild(1));
+                    deduplicationExprList.add(expr.getChild(1));
+                }
+            } else {
+                List<Expr> childrenExprs = expr.getChildren();
+                for (Expr childrenExpr : childrenExprs.subList(1, 
childrenExprs.size())) {
+                    if (!set.contains(childrenExpr)) {
+                        set.add(childrenExpr);
+                        deduplicationExprList.add(childrenExpr);
+                    }
+                }
+            }
+        }
+        return deduplicationExprList;
+    }
+
     /**
      * Convert RangeSet to Compound Predicate
      * @param slotRef: <k1>
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
index 8d1ad8b78c..914bc338bc 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
@@ -2238,5 +2238,22 @@ public class QueryPlanTest extends TestWithFeService {
         sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) 
and (scan_bytes = 2 or scan_bytes = 3)";
         explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, 
"EXPLAIN " + sql);
         Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN 
(1, 2), `scan_bytes` IN (2, 3)"));
+
+        sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or 
query_time = 3 or query_time = 1";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, 
"EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN 
(1, 2, 3)"));
+
+        sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or 
query_time in (3, 2)";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, 
"EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN 
(1, 2, 3)"));
+
+        
connectContext.getSessionVariable().setRewriteOrToInPredicateThreshold(100);
+        sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or 
query_time in (3, 4)";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, 
"EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 
1 OR `query_time` = 2 OR `query_time` IN (3, 4))"));
+
+        sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) 
and query_time in (3, 4)";
+        explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, 
"EXPLAIN " + sql);
+        Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 
1 OR `query_time` = 2), `query_time` IN (3, 4)"));
     }
 }
diff --git a/regression-test/data/performance_p0/redundant_conjuncts.out 
b/regression-test/data/performance_p0/redundant_conjuncts.out
index 7dbabccf37..98178f31aa 100644
--- a/regression-test/data/performance_p0/redundant_conjuncts.out
+++ b/regression-test/data/performance_p0/redundant_conjuncts.out
@@ -23,7 +23,7 @@ PLAN FRAGMENT 0
 
   0:VOlapScanNode
      TABLE: 
default_cluster:regression_test_performance_p0.redundant_conjuncts(redundant_conjuncts),
 PREAGGREGATION: OFF. Reason: No AggregateInfo
-     PREDICATES: `k1` IN (1, 2)
+     PREDICATES: (`k1` = 1 OR `k1` = 2)
      partitions=0/1, tablets=0/0, tabletList=
      cardinality=0, avgRowSize=8.0, numNodes=1
 
diff --git a/regression-test/suites/performance_p0/redundant_conjuncts.groovy 
b/regression-test/suites/performance_p0/redundant_conjuncts.groovy
index 14624a8049..c9ed28b026 100644
--- a/regression-test/suites/performance_p0/redundant_conjuncts.groovy
+++ b/regression-test/suites/performance_p0/redundant_conjuncts.groovy
@@ -39,6 +39,7 @@ suite("redundant_conjuncts") {
     EXPLAIN SELECT v1 FROM redundant_conjuncts WHERE k1 = 1 AND k1 = 1;
     """
 
+    sql "set REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 100"
     qt_redundant_conjuncts_gnerated_by_extract_common_filter """
     EXPLAIN SELECT v1 FROM redundant_conjuncts WHERE k1 = 1 OR k1 = 2;
     """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to