This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 38a62d890fd [opt](nereids) simplify arithmethic handle with mix add/sub/multiply/divide (#45543) 38a62d890fd is described below commit 38a62d890fde29d15cac04ce70326baeba0813af Author: yujun <yu...@selectdb.com> AuthorDate: Fri Jan 3 14:16:11 2025 +0800 [opt](nereids) simplify arithmethic handle with mix add/sub/multiply/divide (#45543) ### What problem does this PR solve? Two optimizations: 1. Handle mixed add / sub / multiply / divide. SimplifyArithmeticRule previously simplified only add-sub chains or only multiply-divide chains, but not both within the same expression. For example, if the expression root is an add, only the add-sub chain was simplified and the multiply-divide sub-expressions were not. For the expression a + 10 + (b * 2 * 3 * (c + 4 + 5)) + 20, constant folding plus this rule used to produce a + (b * (c + 4 + 5) * 2 * 3) + 30; after this PR it produces a + (b * (c + 9) * 6) + 30. 2. Handle cast. SimplifyArithmeticRule previously did not handle casts. For example, for the expression cast(a * 2 * 30 as double) / cast(10 as double), constant folding plus this rule used to produce cast(a * 60 as double) / 10.0; after this PR it produces cast(a as double) * 6.0. --- .../expression/rules/SimplifyArithmeticRule.java | 86 +++++++++++++++------- .../expression/SimplifyArithmeticRuleTest.java | 27 +++++-- 2 files changed, 80 insertions(+), 33 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java index 44d6505b003..6eea495e5cf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java @@ -22,10 +22,12 @@ import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRuleType; 
import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.TypeCoercionUtils; import org.apache.doris.nereids.util.TypeUtils; import org.apache.doris.nereids.util.Utils; @@ -35,6 +37,7 @@ import com.google.common.collect.Lists; import java.util.List; import java.util.Optional; +import java.util.function.Predicate; /** * Simplify arithmetic rule. @@ -91,6 +94,9 @@ public class SimplifyArithmeticRule implements ExpressionPatternRuleFactory { } // 2. move variables to left side and move constants to right sid. for (Operand operand : flattedExpressions) { + if (operand.expression instanceof BinaryArithmetic) { + operand.expression = simplify((BinaryArithmetic) operand.expression); + } if (operand.expression.isConstant()) { constants.add(operand); } else { @@ -129,45 +135,73 @@ public class SimplifyArithmeticRule implements ExpressionPatternRuleFactory { } } + // isAddOrSub: true for extract only "+" or "-" sub expressions, false for extract only "*" or "/" sub expressions private static List<Operand> flatten(Expression expr, boolean isAddOrSub) { List<Operand> result = Lists.newArrayList(); - if (isAddOrSub) { - flattenAddSubtract(true, expr, result); - } else { - flattenMultiplyDivide(true, expr, result); - } + doFlatten(true, expr, isAddOrSub, result, Optional.empty()); return result; } - private static void flattenAddSubtract(boolean flag, Expression expr, List<Operand> result) { - if (TypeUtils.isAddOrSubtract(expr)) { - BinaryArithmetic arithmetic = (BinaryArithmetic) expr; - flattenAddSubtract(flag, arithmetic.left(), 
result); - if (TypeUtils.isSubtract(expr) && !flag) { - flattenAddSubtract(true, arithmetic.right(), result); - } else if (TypeUtils.isAdd(expr) && !flag) { - flattenAddSubtract(false, arithmetic.right(), result); + // flag: true for '+' or '*', false for '-' or '/' + // isAddOrSub: true for extract only "+" or "-" sub expressions, false for extract only "*" or "/" sub expressions + private static void doFlatten(boolean flag, Expression expr, boolean isAddOrSub, List<Operand> result, + Optional<DataType> castType) { + // cast (a * 10 as double) * (cast 20 as double) + // => cast(a as double) * (cast 10 as double) * (cast 20 as double) + BinaryArithmetic arithmetic = null; + Predicate<Expression> isPositiveArithmetic = isAddOrSub + ? TypeUtils::isAdd : TypeUtils::isMultiply; + Predicate<Expression> isNegativeArithmetic = isAddOrSub + ? TypeUtils::isSubtract : TypeUtils::isDivide; + Predicate<Expression> isPosNegArithmetic = isPositiveArithmetic.or(isNegativeArithmetic); + if (isPosNegArithmetic.test(expr)) { + arithmetic = (BinaryArithmetic) expr; + } else if (expr instanceof Cast && hasConstantOperand(expr, isAddOrSub)) { + Cast cast = (Cast) expr; + if (isPosNegArithmetic.test(cast.child())) { + arithmetic = (BinaryArithmetic) cast.child(); + castType = Optional.of(cast.getDataType()); + } + } + if (arithmetic != null) { + doFlatten(flag, arithmetic.left(), isAddOrSub, result, castType); + if (isNegativeArithmetic.test(arithmetic) && !flag) { + doFlatten(true, arithmetic.right(), isAddOrSub, result, castType); + } else if (isPositiveArithmetic.test(arithmetic) && !flag) { + doFlatten(false, arithmetic.right(), isAddOrSub, result, castType); } else { - flattenAddSubtract(!TypeUtils.isSubtract(expr), arithmetic.right(), result); + doFlatten(!isNegativeArithmetic.test(arithmetic), arithmetic.right(), isAddOrSub, result, castType); } } else { - result.add(Operand.of(flag, expr)); + if (castType.isPresent()) { + result.add(Operand.of(flag, 
TypeCoercionUtils.castIfNotSameType(expr, castType.get()))); + } else { + result.add(Operand.of(flag, expr)); + } } } - private static void flattenMultiplyDivide(boolean flag, Expression expr, List<Operand> result) { - if (TypeUtils.isMultiplyOrDivide(expr)) { - BinaryArithmetic arithmetic = (BinaryArithmetic) expr; - flattenMultiplyDivide(flag, arithmetic.left(), result); - if (TypeUtils.isDivide(expr) && !flag) { - flattenMultiplyDivide(true, arithmetic.right(), result); - } else if (TypeUtils.isMultiply(expr) && !flag) { - flattenMultiplyDivide(false, arithmetic.right(), result); - } else { - flattenMultiplyDivide(!TypeUtils.isDivide(expr), arithmetic.right(), result); + private static boolean hasConstantOperand(Expression expr, boolean isAddOrSub) { + if (expr.isConstant()) { + return true; + } + + Predicate<Expression> checkArithmetic = isAddOrSub + ? TypeUtils::isAddOrSubtract : TypeUtils::isMultiplyOrDivide; + BinaryArithmetic arithmetic = null; + if (checkArithmetic.test(expr)) { + arithmetic = (BinaryArithmetic) expr; + } else if (expr instanceof Cast) { + Cast cast = (Cast) expr; + if (checkArithmetic.test(cast.child())) { + arithmetic = (BinaryArithmetic) cast.child(); } + } + if (arithmetic != null) { + return hasConstantOperand(arithmetic.left(), isAddOrSub) + || hasConstantOperand(arithmetic.right(), isAddOrSub); } else { - result.add(Operand.of(flag, expr)); + return false; } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java index f23aefe5267..92eb90e93b0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java @@ -46,12 +46,19 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper { 
assertRewriteAfterTypeCoercion("IA + 2 - ((1 - IB) - (3 + IC))", "IA + IB + IC + 4"); assertRewriteAfterTypeCoercion("IA * IB + 2 - IC * 2", "(IA * IB) - (IC * 2) + 2"); assertRewriteAfterTypeCoercion("IA * IB", "IA * IB"); + assertRewriteAfterTypeCoercion("IA * IB / 2 * 2", "cast((IA * IB) as DOUBLE) / 1.0"); assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as DOUBLE) / 4.0"); assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as DOUBLE) / 4.0"); assertRewriteAfterTypeCoercion("IA * (IB / 2) * 2)", "cast(IA as DOUBLE) * cast(IB as DOUBLE) / 1.0"); assertRewriteAfterTypeCoercion("IA * (IB / 2) * (IC + 1))", "cast(IA as DOUBLE) * cast(IB as DOUBLE) * cast((IC + 1) as DOUBLE) / 2.0"); assertRewriteAfterTypeCoercion("IA * IB / 2 / IC * 2 * ID / 4", "(((cast((IA * IB) as DOUBLE) / cast(IC as DOUBLE)) * cast(ID as DOUBLE)) / 4.0)"); + assertRewriteAfterTypeCoercion("-1 + (10 - 20) * (3 - 6) - (100 - 200) * (6 - 3)", "329"); + assertRewriteAfterTypeCoercion("IA - 10 + (IB * 2 * 3) + 20", "IA + (IB * 6) - (-10)"); + assertRewriteAfterTypeCoercion("IA / 10 * (IB - 2 + 3) * 20", "((cast(IA as DOUBLE) * cast((IB - (-1)) as DOUBLE)) / 0.5)"); + assertRewriteAfterTypeCoercion("1 + ((IA * 2 * 3) * 10 / 10)", "((cast(IA as DOUBLE) * 6.0) + 1.0)"); + assertRewriteAfterTypeCoercion("1 + (IA * 2 * 20 / (IB + 5 + (IC * 10 * 20 / 50 + 5 + 6) + 20) / 20) * (ID * 5 * 6 / (IE + 20 + 30)) + 200", + "(((((cast(IA as DOUBLE) / ((cast(IB as DOUBLE) + (cast(IC as DOUBLE) * 4.0)) + 36.0)) * cast(ID as DOUBLE)) / cast((IE + 50) as DOUBLE)) * 60.0) + 201.0)"); } @Test @@ -69,18 +76,24 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper { assertRewriteAfterTypeCoercion("IA - 2 - ((-IB - 1) - (3 + (IC + 4)))", "(((IA + IB) + IC) - ((((2 + 0) - 1) - 3) - 4))"); // multiply and divide - assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", "((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * 
cast((IC * 3) as DOUBLE))"); - assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))", "(((cast(IA as DOUBLE) / cast((IB * 1) as DOUBLE)) / cast(IC as DOUBLE)) / ((cast(2 as DOUBLE) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))"); - assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast((IC * 4) as DOUBLE)) / ((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)))"); - assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((IC * (3 * 4)) as DOUBLE)) / (cast(2 as DOUBLE) / cast(1 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", + "(((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) * cast (3 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast(IC as DOUBLE))"); + assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))", + "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast(IC as DOUBLE)) / (((cast(2 as DOUBLE) * cast(1 as DOUBLE)) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))", + "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast(IC as DOUBLE)) / (((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)) * cast(4 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", + "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast(IC as DOUBLE)) / (((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))"); // hybrid // root is subtract - assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))", "(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * (cast((1 - IB) as DOUBLE) - (cast(3 as DOUBLE) / cast(IC as DOUBLE)))))"); - assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", "((cast(((0 - 2) - IA) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))"); + 
assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))", + "(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * ((cast(1 as DOUBLE) - cast(IB as DOUBLE)) - (cast(3 as DOUBLE) / cast(IC as DOUBLE)))))"); + assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", + "((((cast(0 as DOUBLE) - cast(2 as DOUBLE)) - cast(IA as DOUBLE)) - cast((IB * 1) as DOUBLE)) + (cast(IC as DOUBLE) * (cast(3 as DOUBLE) / cast(4 as DOUBLE))))"); // root is add assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast(((3 - 4) - IC) as DOUBLE)))"); - assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC + 4)))", "(((((0 + 2) - 1) - IA) + IB) - (3 * (IC + 4)))"); + assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC + 4)))", "(((((0 + 2) - 1) - IA) + IB) - ((IC + 4) * 3))"); // root is multiply assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as DOUBLE)) / cast(2 as DOUBLE))"); assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - 1) - IB) as DOUBLE)) / cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))"); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org