This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 38a62d890fd [opt](nereids) simplify arithmethic handle with mix 
add/sub/multiply/divide (#45543)
38a62d890fd is described below

commit 38a62d890fde29d15cac04ce70326baeba0813af
Author: yujun <yu...@selectdb.com>
AuthorDate: Fri Jan 3 14:16:11 2025 +0800

    [opt](nereids) simplify arithmethic handle with mix add/sub/multiply/divide 
(#45543)
    
    ### What problem does this PR solve?
    
    Two optimizations:
    
    1. Handle mixed add / sub / multiply / divide
    
    SimplifyArithmeticRule previously handled only add-sub chains or only
    multiply-divide chains, but not both in the same expression.
    
    For example, if the expression root is an add, only the add-sub chain
    was simplified; nested multiply-divide sub-expressions were not.
    
    For the expression a + 10 + (b * 2 * 3 * (c + 4 + 5)) + 20, constant
    folding plus this rule used to produce a + (b * (c + 4 + 5) * 2 * 3) + 30;
    with this PR it produces a + (b * (c + 9) * 6) + 30.
    
    2. handle cast
    
    SimplifyArithmeticRule previously did not handle casts.
    
    For example, for the expression cast(a * 2 * 30 as double) / cast(10 as
    double), constant folding plus this rule used to produce cast(a * 60 as
    double) / 10.0; with this PR it produces cast(a as double) * 6.0.
---
 .../expression/rules/SimplifyArithmeticRule.java   | 86 +++++++++++++++-------
 .../expression/SimplifyArithmeticRuleTest.java     | 27 +++++--
 2 files changed, 80 insertions(+), 33 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java
index 44d6505b003..6eea495e5cf 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java
@@ -22,10 +22,12 @@ import 
org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
 import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
 import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
+import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Divide;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Multiply;
 import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.types.DataType;
 import org.apache.doris.nereids.util.TypeCoercionUtils;
 import org.apache.doris.nereids.util.TypeUtils;
 import org.apache.doris.nereids.util.Utils;
@@ -35,6 +37,7 @@ import com.google.common.collect.Lists;
 
 import java.util.List;
 import java.util.Optional;
+import java.util.function.Predicate;
 
 /**
  * Simplify arithmetic rule.
@@ -91,6 +94,9 @@ public class SimplifyArithmeticRule implements 
ExpressionPatternRuleFactory {
         }
         // 2. move variables to left side and move constants to right sid.
         for (Operand operand : flattedExpressions) {
+            if (operand.expression instanceof BinaryArithmetic) {
+                operand.expression = simplify((BinaryArithmetic) 
operand.expression);
+            }
             if (operand.expression.isConstant()) {
                 constants.add(operand);
             } else {
@@ -129,45 +135,73 @@ public class SimplifyArithmeticRule implements 
ExpressionPatternRuleFactory {
         }
     }
 
+    // isAddOrSub: true for extract only "+" or "-" sub expressions, false for 
extract only "*" or "/" sub expressions
     private static List<Operand> flatten(Expression expr, boolean isAddOrSub) {
         List<Operand> result = Lists.newArrayList();
-        if (isAddOrSub) {
-            flattenAddSubtract(true, expr, result);
-        } else {
-            flattenMultiplyDivide(true, expr, result);
-        }
+        doFlatten(true, expr, isAddOrSub, result, Optional.empty());
         return result;
     }
 
-    private static void flattenAddSubtract(boolean flag, Expression expr, 
List<Operand> result) {
-        if (TypeUtils.isAddOrSubtract(expr)) {
-            BinaryArithmetic arithmetic = (BinaryArithmetic) expr;
-            flattenAddSubtract(flag, arithmetic.left(), result);
-            if (TypeUtils.isSubtract(expr) && !flag) {
-                flattenAddSubtract(true, arithmetic.right(), result);
-            } else if (TypeUtils.isAdd(expr) && !flag) {
-                flattenAddSubtract(false, arithmetic.right(), result);
+    // flag: true for '+' or '*', false for '-' or '/'
+    // isAddOrSub: true for extract only "+" or "-" sub expressions, false for 
extract only "*" or "/" sub expressions
+    private static void doFlatten(boolean flag, Expression expr, boolean 
isAddOrSub, List<Operand> result,
+            Optional<DataType> castType) {
+        // cast (a * 10 as double)  *  (cast 20 as double)
+        // => cast(a as double) * (cast 10 as double) * (cast 20 as double)
+        BinaryArithmetic arithmetic = null;
+        Predicate<Expression> isPositiveArithmetic = isAddOrSub
+                ? TypeUtils::isAdd : TypeUtils::isMultiply;
+        Predicate<Expression> isNegativeArithmetic = isAddOrSub
+                ? TypeUtils::isSubtract : TypeUtils::isDivide;
+        Predicate<Expression> isPosNegArithmetic = 
isPositiveArithmetic.or(isNegativeArithmetic);
+        if (isPosNegArithmetic.test(expr)) {
+            arithmetic = (BinaryArithmetic) expr;
+        } else if (expr instanceof Cast && hasConstantOperand(expr, 
isAddOrSub)) {
+            Cast cast = (Cast) expr;
+            if (isPosNegArithmetic.test(cast.child())) {
+                arithmetic = (BinaryArithmetic) cast.child();
+                castType = Optional.of(cast.getDataType());
+            }
+        }
+        if (arithmetic != null) {
+            doFlatten(flag, arithmetic.left(), isAddOrSub, result, castType);
+            if (isNegativeArithmetic.test(arithmetic) && !flag) {
+                doFlatten(true, arithmetic.right(), isAddOrSub, result, 
castType);
+            } else if (isPositiveArithmetic.test(arithmetic) && !flag) {
+                doFlatten(false, arithmetic.right(), isAddOrSub, result, 
castType);
             } else {
-                flattenAddSubtract(!TypeUtils.isSubtract(expr), 
arithmetic.right(), result);
+                doFlatten(!isNegativeArithmetic.test(arithmetic), 
arithmetic.right(), isAddOrSub, result, castType);
             }
         } else {
-            result.add(Operand.of(flag, expr));
+            if (castType.isPresent()) {
+                result.add(Operand.of(flag, 
TypeCoercionUtils.castIfNotSameType(expr, castType.get())));
+            } else {
+                result.add(Operand.of(flag, expr));
+            }
         }
     }
 
-    private static void flattenMultiplyDivide(boolean flag, Expression expr, 
List<Operand> result) {
-        if (TypeUtils.isMultiplyOrDivide(expr)) {
-            BinaryArithmetic arithmetic = (BinaryArithmetic) expr;
-            flattenMultiplyDivide(flag, arithmetic.left(), result);
-            if (TypeUtils.isDivide(expr) && !flag) {
-                flattenMultiplyDivide(true, arithmetic.right(), result);
-            } else if (TypeUtils.isMultiply(expr) && !flag) {
-                flattenMultiplyDivide(false, arithmetic.right(), result);
-            } else {
-                flattenMultiplyDivide(!TypeUtils.isDivide(expr), 
arithmetic.right(), result);
+    private static boolean hasConstantOperand(Expression expr, boolean 
isAddOrSub) {
+        if (expr.isConstant()) {
+            return true;
+        }
+
+        Predicate<Expression> checkArithmetic = isAddOrSub
+                ? TypeUtils::isAddOrSubtract : TypeUtils::isMultiplyOrDivide;
+        BinaryArithmetic arithmetic = null;
+        if (checkArithmetic.test(expr)) {
+            arithmetic = (BinaryArithmetic) expr;
+        } else if (expr instanceof Cast) {
+            Cast cast = (Cast) expr;
+            if (checkArithmetic.test(cast.child())) {
+                arithmetic = (BinaryArithmetic) cast.child();
             }
+        }
+        if (arithmetic != null) {
+            return hasConstantOperand(arithmetic.left(), isAddOrSub)
+                    || hasConstantOperand(arithmetic.right(), isAddOrSub);
         } else {
-            result.add(Operand.of(flag, expr));
+            return false;
         }
     }
 
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java
index f23aefe5267..92eb90e93b0 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java
@@ -46,12 +46,19 @@ class SimplifyArithmeticRuleTest extends 
ExpressionRewriteTestHelper {
         assertRewriteAfterTypeCoercion("IA + 2 - ((1 - IB) - (3 + IC))", "IA + 
IB + IC + 4");
         assertRewriteAfterTypeCoercion("IA * IB + 2 - IC * 2", "(IA * IB) - 
(IC * 2) + 2");
         assertRewriteAfterTypeCoercion("IA * IB", "IA * IB");
+
         assertRewriteAfterTypeCoercion("IA * IB / 2 * 2", "cast((IA * IB) as 
DOUBLE) / 1.0");
         assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as 
DOUBLE) / 4.0");
         assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as 
DOUBLE) / 4.0");
         assertRewriteAfterTypeCoercion("IA * (IB / 2) * 2)", "cast(IA as 
DOUBLE) * cast(IB as DOUBLE) / 1.0");
         assertRewriteAfterTypeCoercion("IA * (IB / 2) * (IC + 1))", "cast(IA 
as DOUBLE) * cast(IB as DOUBLE) * cast((IC + 1) as DOUBLE) / 2.0");
         assertRewriteAfterTypeCoercion("IA * IB / 2 / IC * 2 * ID / 4", 
"(((cast((IA * IB) as DOUBLE) / cast(IC as DOUBLE)) * cast(ID as DOUBLE)) / 
4.0)");
+        assertRewriteAfterTypeCoercion("-1 + (10 - 20)  * (3 - 6) - (100 - 
200) * (6 - 3)", "329");
+        assertRewriteAfterTypeCoercion("IA - 10 + (IB * 2 * 3) + 20", "IA + 
(IB * 6) - (-10)");
+        assertRewriteAfterTypeCoercion("IA / 10 * (IB - 2 + 3) * 20", 
"((cast(IA as DOUBLE) * cast((IB - (-1)) as DOUBLE)) / 0.5)");
+        assertRewriteAfterTypeCoercion("1 + ((IA * 2 * 3) * 10 / 10)", 
"((cast(IA as DOUBLE) * 6.0) + 1.0)");
+        assertRewriteAfterTypeCoercion("1 + (IA * 2 * 20 / (IB + 5 + (IC * 10 
* 20 / 50 + 5 + 6) + 20) / 20) * (ID * 5 * 6 / (IE + 20 + 30)) + 200",
+                "(((((cast(IA as DOUBLE) / ((cast(IB as DOUBLE) + (cast(IC as 
DOUBLE) * 4.0)) + 36.0)) * cast(ID as DOUBLE)) / cast((IE + 50) as DOUBLE)) * 
60.0) + 201.0)");
     }
 
     @Test
@@ -69,18 +76,24 @@ class SimplifyArithmeticRuleTest extends 
ExpressionRewriteTestHelper {
         assertRewriteAfterTypeCoercion("IA - 2 - ((-IB - 1) - (3 + (IC + 
4)))", "(((IA + IB) + IC) - ((((2 + 0) - 1) - 3) - 4))");
 
         // multiply and divide
-        assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", 
"((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as 
DOUBLE)) * cast((IC * 3) as DOUBLE))");
-        assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))", 
"(((cast(IA as DOUBLE) / cast((IB * 1) as DOUBLE)) / cast(IC as DOUBLE)) / 
((cast(2 as DOUBLE) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))");
-        assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))", 
"(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast((IC * 4) as DOUBLE)) / 
((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)))");
-        assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", 
"(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((IC * (3 * 4)) as DOUBLE)) 
/ (cast(2 as DOUBLE) / cast(1 as DOUBLE)))");
+        assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))",
+                "(((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) * cast (3 as 
DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast(IC as DOUBLE))");
+        assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))",
+                "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast(IC as 
DOUBLE)) / (((cast(2 as DOUBLE) * cast(1 as DOUBLE)) / cast(3 as DOUBLE)) / 
cast(4 as DOUBLE)))");
+        assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))",
+                "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast(IC as 
DOUBLE)) / (((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)) * 
cast(4 as DOUBLE)))");
+        assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))",
+                "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast(IC as 
DOUBLE)) / (((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)) / 
cast(4 as DOUBLE)))");
 
         // hybrid
         // root is subtract
-        assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))", 
"(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * (cast((1 - IB) as DOUBLE) - 
(cast(3 as DOUBLE) / cast(IC as DOUBLE)))))");
-        assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 
4)))", "((cast(((0 - 2) - IA) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 
as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))");
+        assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))",
+                "(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * ((cast(1 as 
DOUBLE) - cast(IB as DOUBLE)) - (cast(3 as DOUBLE) / cast(IC as DOUBLE)))))");
+        assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))",
+                "((((cast(0 as DOUBLE) - cast(2 as DOUBLE)) - cast(IA as 
DOUBLE)) - cast((IB * 1) as DOUBLE)) + (cast(IC as DOUBLE) * (cast(3 as DOUBLE) 
/ cast(4 as DOUBLE))))");
         // root is add
         assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 
4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast(((3 
- 4) - IC) as DOUBLE)))");
-        assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC + 
4)))", "(((((0 + 2) - 1) - IA) + IB) - (3 * (IC + 4)))");
+        assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC + 
4)))", "(((((0 + 2) - 1) - IA) + IB) - ((IC + 4) * 3))");
         // root is multiply
         assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 
4)))", "((cast((0 - IA) as DOUBLE) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as 
DOUBLE)) / cast(2 as DOUBLE))");
         assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 
4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - 1) - IB) as DOUBLE)) / 
cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))");


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to