This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 3047d7dd07 [fix](Nereids) fix or to in rule (#23940) 3047d7dd07 is described below commit 3047d7dd078a38815866032030f9d8262fab927f Author: 谢健 <jianx...@gmail.com> AuthorDate: Wed Sep 6 14:58:20 2023 +0800 [fix](Nereids) fix or to in rule (#23940) The OR-expression context must not propagate across separate OR expressions; each OR must be rewritten using only its own disjuncts. For example: ``` select (a = 1 or a = 2 or a = 3) + (a = 4 or a = 5 or a = 6) = select a in [1, 2, 3] + a in [4,5,6] != select a in [1, 2, 3] + a in [1, 2, 3, 4, 5, 6] ``` --- .../nereids/rules/expression/rules/OrToIn.java | 55 +++++++++------------- .../doris/nereids/rules/rewrite/OrToInTest.java | 27 +++++++++-- 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java index a54d5f5369..aaa077d199 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java @@ -19,13 +19,11 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.rules.OrToIn.OrToInContext; -import org.apache.doris.nereids.trees.expressions.And; -import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.literal.Literal; import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.util.ExpressionUtils; @@ -57,7 +55,7 @@ import java.util.Set; * adding any additional rule-specific fields to the default ExpressionRewriteContext. However, the entire expression * rewrite framework always passes an ExpressionRewriteContext of type context to all rules. */ -public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements +public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext> implements ExpressionRewriteRule<ExpressionRewriteContext> { public static final OrToIn INSTANCE = new OrToIn(); @@ -66,25 +64,20 @@ public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements @Override public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, new OrToInContext()); + return expr.accept(this, null); } @Override - public Expression visitCompoundPredicate(CompoundPredicate compoundPredicate, OrToInContext context) { - if (compoundPredicate instanceof And) { - return compoundPredicate.withChildren(compoundPredicate.child(0).accept(new OrToIn(), - new OrToInContext()), - compoundPredicate.child(1).accept(new OrToIn(), - new OrToInContext())); - } - List<Expression> expressions = ExpressionUtils.extractDisjunction(compoundPredicate); + public Expression visitOr(Or or, ExpressionRewriteContext ctx) { + Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>(); + List<Expression> expressions = ExpressionUtils.extractDisjunction(or); for (Expression expression : expressions) { if (expression instanceof EqualTo) { - addSlotToLiteralMap((EqualTo) expression, context); + addSlotToLiteralMap((EqualTo) expression, slotNameToLiteral); } } List<Expression> rewrittenOr = new ArrayList<>(); - for (Map.Entry<NamedExpression, Set<Literal>> entry : context.slotNameToLiteral.entrySet()) { + for (Map.Entry<NamedExpression, Set<Literal>> entry : 
slotNameToLiteral.entrySet()) { Set<Literal> literals = entry.getValue(); if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { InPredicate inPredicate = new InPredicate(entry.getKey(), ImmutableList.copyOf(entry.getValue())); @@ -92,26 +85,26 @@ public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements } } for (Expression expression : expressions) { - if (!ableToConvertToIn(expression, context)) { - rewrittenOr.add(expression); + if (!ableToConvertToIn(expression, slotNameToLiteral)) { + rewrittenOr.add(expression.accept(this, null)); } } return ExpressionUtils.or(rewrittenOr); } - private void addSlotToLiteralMap(EqualTo equal, OrToInContext context) { + private void addSlotToLiteralMap(EqualTo equal, Map<NamedExpression, Set<Literal>> slotNameToLiteral) { Expression left = equal.left(); Expression right = equal.right(); if (left instanceof NamedExpression && right instanceof Literal) { - addSlotToLiteral((NamedExpression) left, (Literal) right, context); + addSlotToLiteral((NamedExpression) left, (Literal) right, slotNameToLiteral); } if (right instanceof NamedExpression && left instanceof Literal) { - addSlotToLiteral((NamedExpression) right, (Literal) left, context); + addSlotToLiteral((NamedExpression) right, (Literal) left, slotNameToLiteral); } } - private boolean ableToConvertToIn(Expression expression, OrToInContext context) { + private boolean ableToConvertToIn(Expression expression, Map<NamedExpression, Set<Literal>> slotNameToLiteral) { if (!(expression instanceof EqualTo)) { return false; } @@ -126,24 +119,18 @@ public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements namedExpression = (NamedExpression) right; } return namedExpression != null - && findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, context) + && findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression, slotNameToLiteral) >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD; } - public void addSlotToLiteral(NamedExpression namedExpression, 
Literal literal, OrToInContext context) { - Set<Literal> literals = context.slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>()); + public void addSlotToLiteral(NamedExpression namedExpression, Literal literal, + Map<NamedExpression, Set<Literal>> slotNameToLiteral) { + Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>()); literals.add(literal); } - public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression, OrToInContext context) { - return context.slotNameToLiteral.getOrDefault(namedExpression, Collections.emptySet()).size(); - } - - /** - * Context of OrToIn - */ - public static class OrToInContext { - public final Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>(); - + public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression namedExpression, + Map<NamedExpression, Set<Literal>> slotNameToLiteral) { + return slotNameToLiteral.getOrDefault(namedExpression, Collections.emptySet()).size(); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java index 651c330c55..f77a66dd88 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java @@ -33,10 +33,10 @@ import org.junit.jupiter.api.Test; import java.util.List; import java.util.Set; -public class OrToInTest extends ExpressionRewriteTestHelper { +class OrToInTest extends ExpressionRewriteTestHelper { @Test - public void test1() { + void test1() { String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)"; Expression expression = PARSER.parseExpression(expr); Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); @@ -59,7 +59,7 @@ public class OrToInTest extends ExpressionRewriteTestHelper { } @Test - public void test2() { 
+ void test2() { String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4"; Expression expression = PARSER.parseExpression(expr); Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); @@ -68,7 +68,7 @@ public class OrToInTest extends ExpressionRewriteTestHelper { } @Test - public void test3() { + void test3() { String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)"; Expression expression = PARSER.parseExpression(expr); Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); @@ -90,4 +90,23 @@ public class OrToInTest extends ExpressionRewriteTestHelper { } } + @Test + void test4() { + String expr = "case when col = 1 or col = 2 or col = 3 then 1" + + " when col = 4 or col = 5 or col = 6 then 1 else 0 end"; + Expression expression = PARSER.parseExpression(expr); + Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END", + rewritten.toSql()); + } + + @Test + void test5() { + String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))"; + Expression expression = PARSER.parseExpression(expr); + Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))", + rewritten.toSql()); + } + } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org