This is an automated email from the ASF dual-hosted git repository. dataroaring pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new bec367b719d [fix](nereids) fix fold constant return wrong scale of datetime type (#50142) (#50717) bec367b719d is described below commit bec367b719d182ce59767b87e2582ec1ba924ab7 Author: 924060929 <lanhuaj...@selectdb.com> AuthorDate: Tue May 13 22:30:54 2025 +0800 [fix](nereids) fix fold constant return wrong scale of datetime type (#50142) (#50717) cherry pick from #50142 --- .../expression/rules/FoldConstantRuleOnFE.java | 29 ++++++++---- .../rules/SimplifyConditionalFunction.java | 51 +++++++++++++--------- .../doris/nereids/util/TypeCoercionUtils.java | 22 ++++++++++ .../nereids/rules/expression/FoldConstantTest.java | 20 ++++++++- .../rules/SimplifyConditionalFunctionTest.java | 35 +++++++++++++++ 5 files changed, 127 insertions(+), 30 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index 2e3b168926c..424673870b8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -91,6 +91,7 @@ import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.DateLikeType; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.TypeCoercionUtils; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.GlobalVariable; import org.apache.doris.thrift.TUniqueId; @@ -537,6 +538,7 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule @Override public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { + CaseWhen originCaseWhen = caseWhen; caseWhen = rewriteChildren(caseWhen, context); Expression newDefault = null; boolean foundNewDefault = false; @@ -562,7 +564,10 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule defaultResult = newDefault; } if (whenClauses.isEmpty()) { - return defaultResult == null ? new NullLiteral(caseWhen.getDataType()) : defaultResult; + return TypeCoercionUtils.ensureSameResultType( + originCaseWhen, defaultResult == null ? new NullLiteral(caseWhen.getDataType()) : defaultResult, + context + ); } if (defaultResult == null) { if (caseWhen.getDataType().isNullType()) { @@ -570,21 +575,24 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule // it's safe to return null literal here return new NullLiteral(); } else { - return new CaseWhen(whenClauses); + return TypeCoercionUtils.ensureSameResultType(originCaseWhen, new CaseWhen(whenClauses), context); } } - return new CaseWhen(whenClauses, defaultResult); + return TypeCoercionUtils.ensureSameResultType( + originCaseWhen, new CaseWhen(whenClauses, defaultResult), context + ); } @Override public Expression visitIf(If ifExpr, ExpressionRewriteContext context) { + If originIf = ifExpr; ifExpr = rewriteChildren(ifExpr, context); if (ifExpr.child(0) instanceof NullLiteral || ifExpr.child(0).equals(BooleanLiteral.FALSE)) { - return ifExpr.child(2); + return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(2), context); } else if (ifExpr.child(0).equals(BooleanLiteral.TRUE)) { - return ifExpr.child(1); + return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(1), context); } - return ifExpr; + return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr, context); } @Override @@ -682,17 +690,20 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule @Override public Expression visitNvl(Nvl nvl, ExpressionRewriteContext context) { + Nvl originNvl = nvl; + nvl = rewriteChildren(nvl, context); + for (Expression expr : nvl.children()) { if (expr.isLiteral()) { if (!expr.isNullLiteral()) { - return expr; + return TypeCoercionUtils.ensureSameResultType(originNvl, expr, context); } } else { - return nvl; + return TypeCoercionUtils.ensureSameResultType(originNvl, nvl, context); } } // all nulls - return nvl.child(0); + return TypeCoercionUtils.ensureSameResultType(originNvl, nvl.child(0), context); } private <E extends Expression> E rewriteChildren(E expr, ExpressionRewriteContext context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunction.java index c1c6283e32d..c6f5aed237d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunction.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; @@ -25,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf; import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable; import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.ImmutableList; @@ -37,9 +39,9 @@ public class SimplifyConditionalFunction implements ExpressionPatternRuleFactory @Override public List<ExpressionPatternMatcher<? extends Expression>> buildRules() { return ImmutableList.of( - matchesType(Coalesce.class).then(SimplifyConditionalFunction::rewriteCoalesce), - matchesType(Nvl.class).then(SimplifyConditionalFunction::rewriteNvl), - matchesType(NullIf.class).then(SimplifyConditionalFunction::rewriteNullIf) + matchesType(Coalesce.class).thenApply(SimplifyConditionalFunction::rewriteCoalesce), + matchesType(Nvl.class).thenApply(SimplifyConditionalFunction::rewriteNvl), + matchesType(NullIf.class).thenApply(SimplifyConditionalFunction::rewriteNullIf) ); } @@ -49,33 +51,38 @@ public class SimplifyConditionalFunction implements ExpressionPatternRuleFactory * coalesce(null,null) => null * coalesce(expr1) => expr1 * */ - private static Expression rewriteCoalesce(Coalesce expression) { - if (1 == expression.arity()) { - return expression.child(0); + private static Expression rewriteCoalesce(ExpressionMatchingContext<Coalesce> ctx) { + Coalesce coalesce = ctx.expr; + if (1 == coalesce.arity()) { + return TypeCoercionUtils.ensureSameResultType(coalesce, coalesce.child(0), ctx.rewriteContext); } - if (!(expression.child(0) instanceof NullLiteral) && expression.child(0).nullable()) { - return expression; + if (!(coalesce.child(0) instanceof NullLiteral) && coalesce.child(0).nullable()) { + return TypeCoercionUtils.ensureSameResultType(coalesce, coalesce, ctx.rewriteContext); } ImmutableList.Builder<Expression> childBuilder = ImmutableList.builder(); - for (int i = 0; i < expression.arity(); i++) { - Expression child = expression.children().get(i); + for (int i = 0; i < coalesce.arity(); i++) { + Expression child = coalesce.children().get(i); if (child instanceof NullLiteral) { continue; } if (!child.nullable()) { - return child; + return TypeCoercionUtils.ensureSameResultType(coalesce, child, ctx.rewriteContext); } else { - for (int j = i; j < expression.arity(); j++) { - childBuilder.add(expression.children().get(j)); + for (int j = i; j < coalesce.arity(); j++) { + childBuilder.add(coalesce.children().get(j)); } break; } } List<Expression> newChildren = childBuilder.build(); if (newChildren.isEmpty()) { - return new NullLiteral(expression.getDataType()); + return TypeCoercionUtils.ensureSameResultType( + coalesce, new NullLiteral(coalesce.getDataType()), ctx.rewriteContext + ); } else { - return expression.withChildren(newChildren); + return TypeCoercionUtils.ensureSameResultType( + coalesce, coalesce.withChildren(newChildren), ctx.rewriteContext + ); } } @@ -83,12 +90,13 @@ public class SimplifyConditionalFunction implements ExpressionPatternRuleFactory * nvl(null,R) => R * nvl(L(not-nullable ),R) => L * */ - private static Expression rewriteNvl(Nvl nvl) { + private static Expression rewriteNvl(ExpressionMatchingContext<Nvl> ctx) { + Nvl nvl = ctx.expr; if (nvl.child(0) instanceof NullLiteral) { - return nvl.child(1); + return TypeCoercionUtils.ensureSameResultType(nvl, nvl.child(1), ctx.rewriteContext); } if (!nvl.child(0).nullable()) { - return nvl.child(0); + return TypeCoercionUtils.ensureSameResultType(nvl, nvl.child(0), ctx.rewriteContext); } return nvl; } @@ -97,9 +105,12 @@ public class SimplifyConditionalFunction implements ExpressionPatternRuleFactory * nullif(null, R) => Null * nullif(L, null) => Null */ - private static Expression rewriteNullIf(NullIf nullIf) { + private static Expression rewriteNullIf(ExpressionMatchingContext<NullIf> ctx) { + NullIf nullIf = ctx.expr; if (nullIf.child(0) instanceof NullLiteral || nullIf.child(1) instanceof NullLiteral) { - return new Nullable(nullIf.child(0)); + return TypeCoercionUtils.ensureSameResultType( + nullIf, new Nullable(nullIf.child(0)), ctx.rewriteContext + ); } else { return nullIf; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 5f6333f1ea2..fc494f410c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -23,6 +23,8 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.Config; import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.BinaryOperator; @@ -151,6 +153,26 @@ public class TypeCoercionUtils { private static final Logger LOG = LogManager.getLogger(TypeCoercionUtils.class); + /** + * ensure the result's data type equals to the originExpr's dataType, + * ATTN: this method usually used in fold constant rule + */ + public static Expression ensureSameResultType( + Expression originExpr, Expression result, ExpressionRewriteContext context) { + DataType originDataType = originExpr.getDataType(); + DataType newDataType = result.getDataType(); + if (originDataType.equals(newDataType)) { + return result; + } + // backend can direct use all string like type without cast + if (originDataType.isStringLikeType() && newDataType.isStringLikeType()) { + return result; + } + return FoldConstantRuleOnFE.PATTERN_MATCH_INSTANCE.visitCast( + new Cast(result, originDataType), context + ); + } + /** * Return Optional.empty() if we cannot do implicit cast. */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java index 55c7b279cd3..21934a7ed35 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java @@ -25,6 +25,8 @@ import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE; +import org.apache.doris.nereids.rules.expression.rules.SimplifyConditionalFunction; +import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; @@ -121,6 +123,15 @@ class FoldConstantTest extends ExpressionRewriteTestHelper { assertRewriteAfterTypeCoercion("case when null = 2 then 1 else 4 end", "4"); assertRewriteAfterTypeCoercion("case when null = 2 then 1 end", "null"); assertRewriteAfterTypeCoercion("case when TA = TB then 1 when TC is null then 2 end", "CASE WHEN (TA = TB) THEN 1 WHEN TC IS NULL THEN 2 END"); + + // make sure the case when return datetime(6) + Expression analyzedCaseWhen = ExpressionAnalyzer.analyzeFunction(null, null, PARSER.parseExpression( + "case when true then cast('2025-04-17' as datetime(0)) else cast('2025-04-18 01:02:03.123456' as datetime(6)) end")); + Assertions.assertEquals(DateTimeV2Type.of(6), analyzedCaseWhen.getDataType()); + Assertions.assertEquals(DateTimeV2Type.of(6), ((CaseWhen) analyzedCaseWhen).getWhenClauses().get(0).getResult().getDataType()); + Assertions.assertEquals(DateTimeV2Type.of(6), ((CaseWhen) analyzedCaseWhen).getDefaultValue().get().getDataType()); + Expression foldCaseWhen = executor.rewrite(analyzedCaseWhen, context); + Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldCaseWhen); } @Test @@ -1174,7 +1185,8 @@ class FoldConstantTest extends ExpressionRewriteTestHelper { executor = new ExpressionRuleExecutor(ImmutableList.of( ExpressionAnalyzer.FUNCTION_ANALYZER_RULE, bottomUp( - FoldConstantRule.INSTANCE + FoldConstantRule.INSTANCE, + SimplifyConditionalFunction.INSTANCE ) )); @@ -1182,6 +1194,12 @@ class FoldConstantTest extends ExpressionRewriteTestHelper { assertRewriteExpression("nvl(NULL, NULL)", "NULL"); assertRewriteAfterTypeCoercion("nvl(IA, NULL)", "ifnull(IA, NULL)"); assertRewriteAfterTypeCoercion("nvl(IA, 1)", "ifnull(IA, 1)"); + + Expression foldNvl = executor.rewrite( + PARSER.parseExpression("nvl(cast('2025-04-17' as datetime(0)), cast('2025-04-18 01:02:03.123456' as datetime(6)))"), + context + ); + Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldNvl); } private void assertRewriteExpression(String actualExpression, String expectedExpression) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java index 7ba9cf09ff2..152c0f542e3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyConditionalFunctionTest.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.scalar.Coalesce; import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf; @@ -26,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable; import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.VarcharType; @@ -71,6 +73,18 @@ public class SimplifyConditionalFunctionTest extends ExpressionRewriteTestHelper // coalesce(null, nullable_slot, literal) -> coalesce(nullable_slot, slot, literal) assertRewrite(new Coalesce(slot, nonNullableSlot), new Coalesce(slot, nonNullableSlot)); + + SlotReference datetimeSlot = new SlotReference("dt", DateTimeV2Type.of(0), false); + // coalesce(null_datetime(0), non-nullable_slot_datetime(6)) + assertRewrite( + new Coalesce(new NullLiteral(DateTimeV2Type.of(6)), datetimeSlot), + new Cast(datetimeSlot, DateTimeV2Type.of(6)) + ); + // coalesce(non-nullable_slot_datetime(6), null_datetime(0)) + assertRewrite( + new Coalesce(datetimeSlot, new NullLiteral(DateTimeV2Type.of(6))), + new Cast(datetimeSlot, DateTimeV2Type.of(6)) + ); } @Test @@ -92,6 +106,18 @@ public class SimplifyConditionalFunctionTest extends ExpressionRewriteTestHelper // nvl(null, null) -> null assertRewrite(new Nvl(NullLiteral.INSTANCE, NullLiteral.INSTANCE), new NullLiteral(BooleanType.INSTANCE)); + + SlotReference datetimeSlot = new SlotReference("dt", DateTimeV2Type.of(0), false); + // nvl(null_datetime(0), non-nullable_slot_datetime(6)) + assertRewrite( + new Nvl(new NullLiteral(DateTimeV2Type.of(6)), datetimeSlot), + new Cast(datetimeSlot, DateTimeV2Type.of(6)) + ); + // nvl(non-nullable_slot_datetime(6), null_datetime(0)) + assertRewrite( + new Nvl(datetimeSlot, new NullLiteral(DateTimeV2Type.of(6))), + new Cast(datetimeSlot, DateTimeV2Type.of(6)) + ); } @Test @@ -108,6 +134,15 @@ public class SimplifyConditionalFunctionTest extends ExpressionRewriteTestHelper // nullif(non-nullable_slot, null) -> non-nullable_slot assertRewrite(new NullIf(nonNullableSlot, NullLiteral.INSTANCE), new Nullable(nonNullableSlot)); + + // nullif(null_datetime(0), null_datetime(6)) -> null_datetime(6) + assertRewrite( + new NullIf( + new NullLiteral(DateTimeV2Type.of(0)), + new NullLiteral(DateTimeV2Type.of(6)) + ), + new Cast(new Nullable(new NullLiteral(DateTimeV2Type.of(0))), DateTimeV2Type.of(6)) + ); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org