This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 3047d7dd07 [fix](Nereids) fix or to in rule (#23940)
3047d7dd07 is described below

commit 3047d7dd078a38815866032030f9d8262fab927f
Author: 谢健 <jianx...@gmail.com>
AuthorDate: Wed Sep 6 14:58:20 2023 +0800

    [fix](Nereids) fix or to in rule (#23940)
    
    The OR-to-IN rewrite context must not propagate across separate OR expressions; each OR expression now collects its own slot-to-literal map.
    
    For example:
    ```
    select (a = 1 or a = 2 or a = 3) + (a = 4 or a = 5 or a = 6)
    should be rewritten to:  select a in (1, 2, 3) + a in (4, 5, 6)
    not (previous behavior): select a in (1, 2, 3) + a in (1, 2, 3, 4, 5, 6)
    ```
---
 .../nereids/rules/expression/rules/OrToIn.java     | 55 +++++++++-------------
 .../doris/nereids/rules/rewrite/OrToInTest.java    | 27 +++++++++--
 2 files changed, 44 insertions(+), 38 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
index a54d5f5369..aaa077d199 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
@@ -19,13 +19,11 @@ package org.apache.doris.nereids.rules.expression.rules;
 
 import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
 import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
-import org.apache.doris.nereids.rules.expression.rules.OrToIn.OrToInContext;
-import org.apache.doris.nereids.trees.expressions.And;
-import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InPredicate;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
 import org.apache.doris.nereids.util.ExpressionUtils;
@@ -57,7 +55,7 @@ import java.util.Set;
  * adding any additional rule-specific fields to the default 
ExpressionRewriteContext. However, the entire expression
  * rewrite framework always passes an ExpressionRewriteContext of type context 
to all rules.
  */
-public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements
+public class OrToIn extends 
DefaultExpressionRewriter<ExpressionRewriteContext> implements
         ExpressionRewriteRule<ExpressionRewriteContext> {
 
     public static final OrToIn INSTANCE = new OrToIn();
@@ -66,25 +64,20 @@ public class OrToIn extends 
DefaultExpressionRewriter<OrToInContext> implements
 
     @Override
     public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
-        return expr.accept(this, new OrToInContext());
+        return expr.accept(this, null);
     }
 
     @Override
-    public Expression visitCompoundPredicate(CompoundPredicate 
compoundPredicate, OrToInContext context) {
-        if (compoundPredicate instanceof And) {
-            return 
compoundPredicate.withChildren(compoundPredicate.child(0).accept(new OrToIn(),
-                            new OrToInContext()),
-                    compoundPredicate.child(1).accept(new OrToIn(),
-                            new OrToInContext()));
-        }
-        List<Expression> expressions = 
ExpressionUtils.extractDisjunction(compoundPredicate);
+    public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
+        Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();
+        List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
         for (Expression expression : expressions) {
             if (expression instanceof EqualTo) {
-                addSlotToLiteralMap((EqualTo) expression, context);
+                addSlotToLiteralMap((EqualTo) expression, slotNameToLiteral);
             }
         }
         List<Expression> rewrittenOr = new ArrayList<>();
-        for (Map.Entry<NamedExpression, Set<Literal>> entry : 
context.slotNameToLiteral.entrySet()) {
+        for (Map.Entry<NamedExpression, Set<Literal>> entry : 
slotNameToLiteral.entrySet()) {
             Set<Literal> literals = entry.getValue();
             if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
                 InPredicate inPredicate = new InPredicate(entry.getKey(), 
ImmutableList.copyOf(entry.getValue()));
@@ -92,26 +85,26 @@ public class OrToIn extends 
DefaultExpressionRewriter<OrToInContext> implements
             }
         }
         for (Expression expression : expressions) {
-            if (!ableToConvertToIn(expression, context)) {
-                rewrittenOr.add(expression);
+            if (!ableToConvertToIn(expression, slotNameToLiteral)) {
+                rewrittenOr.add(expression.accept(this, null));
             }
         }
 
         return ExpressionUtils.or(rewrittenOr);
     }
 
-    private void addSlotToLiteralMap(EqualTo equal, OrToInContext context) {
+    private void addSlotToLiteralMap(EqualTo equal, Map<NamedExpression, 
Set<Literal>> slotNameToLiteral) {
         Expression left = equal.left();
         Expression right = equal.right();
         if (left instanceof NamedExpression && right instanceof Literal) {
-            addSlotToLiteral((NamedExpression) left, (Literal) right, context);
+            addSlotToLiteral((NamedExpression) left, (Literal) right, 
slotNameToLiteral);
         }
         if (right instanceof NamedExpression && left instanceof Literal) {
-            addSlotToLiteral((NamedExpression) right, (Literal) left, context);
+            addSlotToLiteral((NamedExpression) right, (Literal) left, 
slotNameToLiteral);
         }
     }
 
-    private boolean ableToConvertToIn(Expression expression, OrToInContext 
context) {
+    private boolean ableToConvertToIn(Expression expression, 
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
         if (!(expression instanceof EqualTo)) {
             return false;
         }
@@ -126,24 +119,18 @@ public class OrToIn extends 
DefaultExpressionRewriter<OrToInContext> implements
             namedExpression = (NamedExpression) right;
         }
         return namedExpression != null
-                && findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, 
context)
+                && findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, 
slotNameToLiteral)
                 >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD;
     }
 
-    public void addSlotToLiteral(NamedExpression namedExpression, Literal 
literal, OrToInContext context) {
-        Set<Literal> literals = 
context.slotNameToLiteral.computeIfAbsent(namedExpression, k -> new 
HashSet<>());
+    public void addSlotToLiteral(NamedExpression namedExpression, Literal 
literal,
+            Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
+        Set<Literal> literals = 
slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
         literals.add(literal);
     }
 
-    public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression 
namedExpression, OrToInContext context) {
-        return context.slotNameToLiteral.getOrDefault(namedExpression, 
Collections.emptySet()).size();
-    }
-
-    /**
-     * Context of OrToIn
-     */
-    public static class OrToInContext {
-        public final Map<NamedExpression, Set<Literal>> slotNameToLiteral = 
new HashMap<>();
-
+    public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression 
namedExpression,
+            Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
+        return slotNameToLiteral.getOrDefault(namedExpression, 
Collections.emptySet()).size();
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java
index 651c330c55..f77a66dd88 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java
@@ -33,10 +33,10 @@ import org.junit.jupiter.api.Test;
 import java.util.List;
 import java.util.Set;
 
-public class OrToInTest extends ExpressionRewriteTestHelper {
+class OrToInTest extends ExpressionRewriteTestHelper {
 
     @Test
-    public void test1() {
+    void test1() {
         String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)";
         Expression expression = PARSER.parseExpression(expr);
         Expression rewritten = new OrToIn().rewrite(expression, new 
ExpressionRewriteContext(null));
@@ -59,7 +59,7 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void test2() {
+    void test2() {
         String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4";
         Expression expression = PARSER.parseExpression(expr);
         Expression rewritten = new OrToIn().rewrite(expression, new 
ExpressionRewriteContext(null));
@@ -68,7 +68,7 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void test3() {
+    void test3() {
         String expr = "(col1 = 1 or col1 = 2) and  (col2 = 3 or col2 = 4)";
         Expression expression = PARSER.parseExpression(expr);
         Expression rewritten = new OrToIn().rewrite(expression, new 
ExpressionRewriteContext(null));
@@ -90,4 +90,23 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
         }
     }
 
+    @Test
+    void test4() {
+        String expr = "case when col = 1 or col = 2 or col = 3 then 1"
+                + "         when col = 4 or col = 5 or col = 6 then 1 else 0 
end";
+        Expression expression = PARSER.parseExpression(expr);
+        Expression rewritten = new OrToIn().rewrite(expression, new 
ExpressionRewriteContext(null));
+        Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN 
(4, 5, 6) THEN 1 ELSE 0 END",
+                rewritten.toSql());
+    }
+
+    @Test
+    void test5() {
+        String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 
5))";
+        Expression expression = PARSER.parseExpression(expr);
+        Expression rewritten = new OrToIn().rewrite(expression, new 
ExpressionRewriteContext(null));
+        Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 
5)))",
+                rewritten.toSql());
+    }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to