This is an automated email from the ASF dual-hosted git repository. gabriellee pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 70f2e8ff80 [fix](nereids)enable decimalv3 by default for nereids (#19906) 70f2e8ff80 is described below commit 70f2e8ff80e08f6f7db13bd0006fe71625f4bb1b Author: starocean999 <40539150+starocean...@users.noreply.github.com> AuthorDate: Wed May 24 13:36:24 2023 +0800 [fix](nereids)enable decimalv3 by default for nereids (#19906) --- .../java/org/apache/doris/catalog/ScalarType.java | 45 +++++++++++++++++++++- .../main/java/org/apache/doris/catalog/Type.java | 8 +++- .../org/apache/doris/analysis/BinaryPredicate.java | 16 +++++--- .../org/apache/doris/analysis/DecimalLiteral.java | 6 ++- .../main/java/org/apache/doris/analysis/Expr.java | 6 +++ .../apache/doris/analysis/FunctionCallExpr.java | 16 ++++---- .../rules/SimplifyDecimalV3Comparison.java | 18 +++++---- .../functions/ComputeSignatureHelper.java | 13 ++++++- .../expressions/functions/SearchSignature.java | 11 +++++- .../expressions/functions/agg/AvgWeighted.java | 6 +-- .../trees/expressions/functions/agg/Histogram.java | 4 +- .../trees/expressions/functions/agg/Stddev.java | 2 +- .../expressions/functions/agg/StddevSamp.java | 2 +- .../trees/expressions/functions/agg/Variance.java | 2 +- .../expressions/functions/agg/VarianceSamp.java | 2 +- .../expressions/functions/scalar/Coalesce.java | 2 +- .../trees/expressions/functions/scalar/If.java | 12 +++--- .../trees/expressions/functions/scalar/Nvl.java | 4 +- .../trees/expressions/literal/DecimalLiteral.java | 7 ++-- .../doris/nereids/util/TypeCoercionUtils.java | 24 ++++++++---- .../doris/rewrite/ExtractCommonFactorsRule.java | 18 ++++----- 21 files changed, 157 insertions(+), 67 deletions(-) diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java index d3d40642aa..5e4a54f948 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java @@ -30,6 +30,7 @@ import com.google.gson.annotations.SerializedName; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -1111,6 +1112,12 @@ public class ScalarType extends Type { return getAssignmentCompatibleDecimalV2Type(t1, t2); } + if ((t1.isDecimalV3() && t2.isDecimalV2()) || (t2.isDecimalV3() && t1.isDecimalV2())) { + int scale = Math.max(t1.scale, t2.scale); + int integerPart = Math.max(t1.precision - t1.scale, t2.precision - t2.scale); + return ScalarType.createDecimalV3Type(integerPart + scale, scale); + } + if (t1.isDecimalV2() || t2.isDecimalV2()) { if (t1.isFloatingPointType() || t2.isFloatingPointType()) { return MAX_DECIMALV2_TYPE; @@ -1118,8 +1125,42 @@ public class ScalarType extends Type { return t1.isDecimalV2() ? t1 : t2; } - if ((t1.isDecimalV3() && t2.isFixedPointType()) || (t2.isDecimalV3() && t1.isFixedPointType())) { - return t1.isDecimalV3() ? t1 : t2; + if (t1.isDecimalV3() || t2.isDecimalV3()) { + if (t1.isFloatingPointType() || t2.isFloatingPointType()) { + return t1.isFloatingPointType() ? t1 : t2; + } else if (t1.isBoolean() || t2.isBoolean()) { + return t1.isDecimalV3() ? t1 : t2; + } + } + + if ((t1.isDecimalV3() && t2.isFixedPointType()) + || (t2.isDecimalV3() && t1.isFixedPointType())) { + int precision; + int scale; + ScalarType intType; + if (t1.isDecimalV3()) { + precision = t1.precision; + scale = t1.scale; + intType = t2; + } else { + precision = t2.precision; + scale = t2.scale; + intType = t1; + } + int integerPart = precision - scale; + if (intType.isScalarType(PrimitiveType.TINYINT) + || intType.isScalarType(PrimitiveType.SMALLINT)) { + integerPart = Math.max(integerPart, new BigDecimal(Short.MAX_VALUE).precision()); + } else if (intType.isScalarType(PrimitiveType.INT)) { + integerPart = Math.max(integerPart, new BigDecimal(Integer.MAX_VALUE).precision()); + } else { + integerPart = ScalarType.MAX_DECIMAL128_PRECISION - scale; + } + if (scale + integerPart <= ScalarType.MAX_DECIMAL128_PRECISION) { + return ScalarType.createDecimalV3Type(scale + integerPart, scale); + } else { + return Type.DOUBLE; + } } if (t1.isDecimalV3() && t2.isDecimalV3()) { diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java index c1bcc9c9ae..979ae38052 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java @@ -1813,8 +1813,12 @@ public abstract class Type { } else { resultDecimalType = PrimitiveType.DECIMAL128; } - return ScalarType.createDecimalType(resultDecimalType, resultPrecision, - Math.max(((ScalarType) t1).getScalarScale(), ((ScalarType) t2).getScalarScale())); + if (resultPrecision <= ScalarType.MAX_DECIMAL128_PRECISION) { + return ScalarType.createDecimalType(resultDecimalType, resultPrecision, Math.max( + ((ScalarType) t1).getScalarScale(), ((ScalarType) t2).getScalarScale())); + } else { + return Type.DOUBLE; + } } if (t1ResultType.isDecimalV3Type() || t2ResultType.isDecimalV3Type()) { return getAssignmentCompatibleType(t1, t2, false); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java index 112460564e..64f802efaf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java @@ -414,9 +414,13 @@ public class BinaryPredicate extends Predicate implements Writable { if (t1 == PrimitiveType.BIGINT && t2 == PrimitiveType.BIGINT) { return Type.getAssignmentCompatibleType(getChild(0).getType(), getChild(1).getType(), false); } - if ((t1 == PrimitiveType.BIGINT || t1 == PrimitiveType.DECIMALV2) - && (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.DECIMALV2)) { - return Type.DECIMALV2; + if ((t1 == PrimitiveType.BIGINT && t2 == PrimitiveType.DECIMALV2) + || (t2 == PrimitiveType.BIGINT && t1 == PrimitiveType.DECIMALV2) + || (t1 == PrimitiveType.LARGEINT && t2 == PrimitiveType.DECIMALV2) + || (t2 == PrimitiveType.LARGEINT && t1 == PrimitiveType.DECIMALV2)) { + // only decimalv3 can hold big and large int + return ScalarType.createDecimalType(PrimitiveType.DECIMAL128, ScalarType.MAX_DECIMAL128_PRECISION, + ScalarType.MAX_DECIMALV2_SCALE); } if ((t1 == PrimitiveType.BIGINT || t1 == PrimitiveType.LARGEINT) && (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.LARGEINT)) { @@ -603,9 +607,9 @@ public class BinaryPredicate extends Predicate implements Writable { } public Range<LiteralExpr> convertToRange() { - Preconditions.checkState(getChild(0) instanceof SlotRef); - Preconditions.checkState(getChild(1) instanceof LiteralExpr); - LiteralExpr literalExpr = (LiteralExpr) getChild(1); + Preconditions.checkState(getChildWithoutCast(0) instanceof SlotRef); + Preconditions.checkState(getChildWithoutCast(1) instanceof LiteralExpr); + LiteralExpr literalExpr = (LiteralExpr) getChildWithoutCast(1); switch (op) { case EQ: return Range.singleton(literalExpr); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java index bc2951e432..9040798b76 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java @@ -286,8 +286,10 @@ public class DecimalLiteral extends LiteralExpr { @Override protected void compactForLiteral(Type type) throws AnalysisException { if (type.isDecimalV3()) { - this.type = ScalarType.createDecimalV3Type(Math.max(this.value.precision(), type.getPrecision()), - Math.max(this.value.scale(), ((ScalarType) type).decimalScale())); + int scale = Math.max(this.value.scale(), ((ScalarType) type).decimalScale()); + int integerPart = Math.max(this.value.precision() - this.value.scale(), + type.getPrecision() - ((ScalarType) type).decimalScale()); + this.type = ScalarType.createDecimalV3Type(integerPart + scale, scale); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index 091fe7b9b8..18b99065f4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -495,6 +495,12 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl return result; } + public Expr getChildWithoutCast(int i) { + Preconditions.checkArgument(i < children.size(), "child index {0} out of range {1}", i, children.size()); + Expr child = children.get(i); + return child instanceof CastExpr ? child.children.get(0) : child; + } + /** * Helper function: analyze list of exprs * diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index 1238c8514e..adffa17b2a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -1404,23 +1404,24 @@ public class FunctionCallExpr extends Expr { fn = getBuiltinFunction(fnName.getFunction(), childTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else if ((fnName.getFunction().equalsIgnoreCase("coalesce") - || fnName.getFunction().equalsIgnoreCase("greatest") - || fnName.getFunction().equalsIgnoreCase("least")) && children.size() > 1) { + || fnName.getFunction().equalsIgnoreCase("least") + || fnName.getFunction().equalsIgnoreCase("greatest")) && children.size() > 1) { Type[] childTypes = collectChildReturnTypes(); Type assignmentCompatibleType = childTypes[0]; - for (int i = 1; i < childTypes.length && assignmentCompatibleType.isDecimalV3(); i++) { - assignmentCompatibleType = - ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, childTypes[i], true); + for (int i = 1; i < childTypes.length; i++) { + assignmentCompatibleType = ScalarType + .getAssignmentCompatibleType(assignmentCompatibleType, childTypes[i], true); } if (assignmentCompatibleType.isDecimalV3()) { for (int i = 0; i < childTypes.length; i++) { - if (assignmentCompatibleType.isDecimalV3() && !childTypes[i].equals(assignmentCompatibleType)) { + if (assignmentCompatibleType.isDecimalV3() + && !childTypes[i].equals(assignmentCompatibleType)) { uncheckedCastChild(assignmentCompatibleType, i); argTypes[i] = assignmentCompatibleType; } } } - fn = getBuiltinFunction(fnName.getFunction(), childTypes, + fn = getBuiltinFunction(fnName.getFunction(), argTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains( fnName.getFunction().toLowerCase())) { @@ -1658,6 +1659,7 @@ public class FunctionCallExpr extends Expr { } else if (!argTypes[i].matchesType(args[ix]) && (!fn.getReturnType().isDecimalV3OrContainsDecimalV3() || (argTypes[i].isValid() && !argTypes[i].isDecimalV3() && args[ix].isDecimalV3()))) { + // || (argTypes[i].isValid() && argTypes[i].getPrimitiveType() != args[ix].getPrimitiveType()))) { uncheckedCastChild(args[ix], i); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java index 93021f0b58..fc1ee0cb91 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java @@ -63,15 +63,19 @@ public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule { private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros(); int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue); - int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue); + int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(trailingZerosValue); Expression castChild = left.child(); Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type); DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType(); - // precision and scale of literal must all smaller than left, otherwise we need to do cast on right. - Preconditions.checkState(scale <= leftType.getScale(), "right scale should not greater than left"); - Preconditions.checkState(precision <= leftType.getPrecision(), "right precision should not greater than left"); - DecimalV3Literal newRight = new DecimalV3Literal( - DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()), trailingZerosValue); - return cp.withChildren(castChild, newRight); + + if (scale <= leftType.getScale() && precision - scale <= leftType.getPrecision() - leftType.getScale()) { + // precision and scale of literal all smaller than left, we don't need the cast + DecimalV3Literal newRight = new DecimalV3Literal( + DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()), + trailingZerosValue); + return cp.withChildren(castChild, newRight); + } else { + return cp; + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java index a59039a934..fba33d0fe0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.catalog.FunctionSignature.TripleFunction; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeV2Type; @@ -32,6 +33,7 @@ import org.apache.doris.nereids.util.ResponsibilityChain; import com.google.common.base.Preconditions; +import java.math.BigDecimal; import java.util.List; import java.util.function.BiFunction; import java.util.stream.Collectors; @@ -160,8 +162,15 @@ public class ComputeSignatureHelper { if (finalType == null) { finalType = DecimalV3Type.forType(arguments.get(i).getDataType()); } else { - finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType, - DecimalV3Type.forType(arguments.get(i).getDataType()), true); + Expression arg = arguments.get(i); + DecimalV3Type argType; + if (arg.isLiteral() && arg.getDataType().isIntegralType()) { + // create decimalV3 with minimum scale enough to hold the integral literal + argType = DecimalV3Type.createDecimalV3Type(new BigDecimal(((Literal) arg).getStringValue())); + } else { + argType = DecimalV3Type.forType(arg.getDataType()); + } + finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType, argType, true); } Preconditions.checkState(finalType.isDecimalV3Type(), "decimalv3 precision promotion failed."); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/SearchSignature.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/SearchSignature.java index 3a462940d9..9e2fa281ed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/SearchSignature.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/SearchSignature.java @@ -28,6 +28,7 @@ import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.Lists; +import java.math.BigDecimal; import java.util.List; import java.util.Optional; import java.util.function.BiFunction; @@ -141,8 +142,14 @@ public class SearchSignature { if (finalType == null) { finalType = DecimalV3Type.forType(arguments.get(i).getDataType()); } else { - finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType, - DecimalV3Type.forType(arguments.get(i).getDataType()), true); + Expression arg = arguments.get(i); + if (arg.isLiteral() && arg.getDataType().isIntegralType()) { + // create decimalV3 with minimum scale enough to hold the integral literal + finalType = DecimalV3Type.createDecimalV3Type(new BigDecimal(((Literal) arg).getStringValue())); + } else { + finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType, + DecimalV3Type.forType(arg.getDataType()), true); + } } if (!finalType.isDecimalV3Type()) { return false; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AvgWeighted.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AvgWeighted.java index 5911c6d360..e2054878d9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AvgWeighted.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AvgWeighted.java @@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DecimalV2Type; -import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.FloatType; import org.apache.doris.nereids.types.IntegerType; @@ -44,14 +43,13 @@ public class AvgWeighted extends AggregateFunction implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV3Type.WILDCARD, DoubleType.INSTANCE) + FunctionSignature.ret(DoubleType.INSTANCE).args(DecimalV2Type.SYSTEM_DEFAULT, DoubleType.INSTANCE) ); /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java index 36053ae0cc..9eeb709810 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java @@ -31,6 +31,7 @@ import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.FloatType; import org.apache.doris.nereids.types.IntegerType; @@ -60,7 +61,8 @@ public class Histogram extends AggregateFunction FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(LargeIntType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(FloatType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DoubleType.INSTANCE), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV2Type.CATALOG_DEFAULT), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV3Type.WILDCARD), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateTimeType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateV2Type.INSTANCE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Stddev.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Stddev.java index ecab5989e9..6cbebbc0ec 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Stddev.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Stddev.java @@ -44,12 +44,12 @@ public class Stddev extends NullableAggregateFunction StdDevOrVarianceFunction, DecimalStddevPrecision { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/StddevSamp.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/StddevSamp.java index 973609a2b2..971af51a57 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/StddevSamp.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/StddevSamp.java @@ -45,12 +45,12 @@ public class StddevSamp extends AggregateFunction StdDevOrVarianceFunction, DecimalStddevPrecision { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Variance.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Variance.java index e0a8877888..79fbcfb764 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Variance.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Variance.java @@ -44,12 +44,12 @@ public class Variance extends NullableAggregateFunction StdDevOrVarianceFunction, DecimalStddevPrecision { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/VarianceSamp.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/VarianceSamp.java index 3473b2afba..8196be54ab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/VarianceSamp.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/VarianceSamp.java @@ -44,12 +44,12 @@ public class VarianceSamp extends AggregateFunction StdDevOrVarianceFunction, AlwaysNullable { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Coalesce.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Coalesce.java index 71b0d8e698..7dca8a7f2e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Coalesce.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Coalesce.java @@ -64,8 +64,8 @@ public class Coalesce extends ScalarFunction FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE), FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE), - FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(DecimalV3Type.WILDCARD).varArgs(DecimalV3Type.WILDCARD), + FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(BitmapType.INSTANCE).varArgs(BitmapType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT), FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java index 75a40f3cea..df8bc78a27 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java @@ -82,10 +82,10 @@ public class If extends ScalarFunction FunctionSignature.ret(DateTimeType.INSTANCE) .args(BooleanType.INSTANCE, DateTimeType.INSTANCE, DateTimeType.INSTANCE), FunctionSignature.ret(DateType.INSTANCE).args(BooleanType.INSTANCE, DateType.INSTANCE, DateType.INSTANCE), - FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT) - .args(BooleanType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(DecimalV3Type.WILDCARD) .args(BooleanType.INSTANCE, DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD), + FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT) + .args(BooleanType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(BitmapType.INSTANCE) .args(BooleanType.INSTANCE, BitmapType.INSTANCE, BitmapType.INSTANCE), FunctionSignature.ret(HllType.INSTANCE).args(BooleanType.INSTANCE, HllType.INSTANCE, HllType.INSTANCE), @@ -125,14 +125,14 @@ public class If extends ScalarFunction ArrayType.of(DateTimeV2Type.SYSTEM_DEFAULT)), FunctionSignature.ret(ArrayType.of(DateV2Type.INSTANCE)) .args(BooleanType.INSTANCE, ArrayType.of(DateV2Type.INSTANCE), ArrayType.of(DateV2Type.INSTANCE)), - FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) - .args(BooleanType.INSTANCE, - ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT), - ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT)), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) .args(BooleanType.INSTANCE, ArrayType.of(DecimalV3Type.WILDCARD), ArrayType.of(DecimalV3Type.WILDCARD)), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) + .args(BooleanType.INSTANCE, + ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT), + ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT)), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) .args(BooleanType.INSTANCE, ArrayType.of(VarcharType.SYSTEM_DEFAULT), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Nvl.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Nvl.java index e0b772d231..cf553199fe 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Nvl.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Nvl.java @@ -68,10 +68,10 @@ public class Nvl extends ScalarFunction .args(DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(DateV2Type.INSTANCE) .args(DateV2Type.INSTANCE, DateV2Type.INSTANCE), - FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT) - .args(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(DecimalV3Type.WILDCARD) .args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD), + FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT) + .args(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE, BitmapType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) .args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalLiteral.java index 10a97a71ed..5cde2e155f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalLiteral.java @@ -37,8 +37,9 @@ public class DecimalLiteral extends Literal { } public DecimalLiteral(DecimalV2Type dataType, BigDecimal value) { - super(DecimalV2Type.createDecimalV2Type(dataType.getPrecision(), dataType.getScale())); - this.value = Objects.requireNonNull(value.setScale(dataType.getScale(), RoundingMode.DOWN)); + super(dataType); + BigDecimal adjustedValue = value.scale() < 0 ? value : value.setScale(dataType.getScale(), RoundingMode.DOWN); + this.value = Objects.requireNonNull(adjustedValue); } @Override @@ -53,7 +54,7 @@ public class DecimalLiteral extends Literal { @Override public LiteralExpr toLegacyLiteral() { - return new org.apache.doris.analysis.DecimalLiteral(value); + return new org.apache.doris.analysis.DecimalLiteral(value, dataType.toCatalogDataType()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index cfdbfa684e..e78286bc29 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -451,8 +451,8 @@ public class TypeCoercionUtils { } DataType commonType = DoubleType.INSTANCE; - if (t1.isDoubleType() || t1.isFloatType() || t1.isLargeIntType() - || t2.isDoubleType() || t2.isFloatType() || t2.isLargeIntType()) { + if (t1.isDoubleType() || t1.isFloatType() + || t2.isDoubleType() || t2.isFloatType()) { // double type } else if (t1.isDecimalV3Type() || t2.isDecimalV3Type()) { // divide should cast to precision and target scale @@ -535,6 +535,9 @@ public class TypeCoercionUtils { break; } } + if (commonType.isFloatType() && (t1.isDecimalV3Type() || t2.isDecimalV3Type())) { + commonType = DoubleType.INSTANCE; + } boolean isBitArithmetic = binaryArithmetic instanceof BitAnd || binaryArithmetic instanceof BitOr @@ -577,13 +580,12 @@ public class TypeCoercionUtils { return castChildren(binaryArithmetic, left, right, DoubleType.INSTANCE); } - // add, subtract and mod should cast children to exactly same type as return type + // add, subtract should cast children to exactly same type as return type if (binaryArithmetic instanceof Add - || binaryArithmetic instanceof Subtract - || binaryArithmetic instanceof Mod) { + || binaryArithmetic instanceof Subtract) { return castChildren(binaryArithmetic, left, right, retType); } - // multiply do not need to cast children to same type + // multiply and mode do not need to cast children to same type return binaryArithmetic.withChildren(castIfNotSameType(left, dt1), castIfNotSameType(right, dt2)); } @@ -959,8 +961,16 @@ public class TypeCoercionUtils { DecimalV3Type.forType(leftType), DecimalV3Type.forType(rightType), true)); } if (leftType instanceof DecimalV2Type || rightType instanceof DecimalV2Type) { - return Optional.of(DecimalV2Type.widerDecimalV2Type( + if (leftType instanceof BigIntType || rightType instanceof BigIntType + || leftType instanceof LargeIntType || rightType instanceof LargeIntType) { + // only decimalv3 can hold big or large int + return Optional + .of(DecimalV3Type.widerDecimalV3Type(DecimalV3Type.forType(leftType), + DecimalV3Type.forType(rightType), true)); + } else { + return Optional.of(DecimalV2Type.widerDecimalV2Type( DecimalV2Type.forType(leftType), DecimalV2Type.forType(rightType))); + } } return Optional.of(commonType); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java index d28fde13ea..dcc6cd4a18 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/ExtractCommonFactorsRule.java @@ -229,7 +229,7 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { if (!singleColumnPredicate(predicate)) { continue; } - SlotRef columnName = (SlotRef) predicate.getChild(0); + SlotRef columnName = (SlotRef) predicate.getChildWithoutCast(0); if (predicate instanceof BinaryPredicate) { Range<LiteralExpr> predicateRange = ((BinaryPredicate) predicate).convertToRange(); if (predicateRange == null) { @@ -319,14 +319,14 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { if (inPredicate.isNotIn()) { return false; } - if (inPredicate.getChild(0) instanceof SlotRef) { + if (inPredicate.getChildWithoutCast(0) instanceof SlotRef) { return true; } return false; } else if (expr instanceof BinaryPredicate) { BinaryPredicate binaryPredicate = (BinaryPredicate) expr; - if (binaryPredicate.getChild(0) instanceof SlotRef - && binaryPredicate.getChild(1) instanceof LiteralExpr) { + if (binaryPredicate.getChildWithoutCast(0) instanceof SlotRef + && binaryPredicate.getChildWithoutCast(1) instanceof LiteralExpr) { return true; } return false; @@ -518,9 +518,9 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { notMergedExprs.add(new CompoundPredicate(Operator.AND, left, right)); } else if (!(predicate instanceof BinaryPredicate) && !(predicate instanceof InPredicate)) { notMergedExprs.add(predicate); - } else if (!(predicate.getChild(0) instanceof SlotRef)) { + } else if (!(predicate.getChildWithoutCast(0) instanceof SlotRef)) { notMergedExprs.add(predicate); - } else if (!(predicate.getChild(1) instanceof LiteralExpr)) { + } else if (!(predicate.getChildWithoutCast(1) instanceof LiteralExpr)) { notMergedExprs.add(predicate); } else if (predicate instanceof BinaryPredicate && ((BinaryPredicate) predicate).getOp() != BinaryPredicate.Operator.EQ) { @@ -529,13 +529,13 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule { && ((InPredicate) predicate).isNotIn()) { notMergedExprs.add(predicate); } else { - TableName tableName = ((SlotRef) predicate.getChild(0)).getTableName(); + TableName tableName = ((SlotRef) predicate.getChildWithoutCast(0)).getTableName(); String columnWithTable; if (tableName != null) { String tblName = tableName.toString(); - columnWithTable = tblName + "." + ((SlotRef) predicate.getChild(0)).getColumnName(); + columnWithTable = tblName + "." + ((SlotRef) predicate.getChildWithoutCast(0)).getColumnName(); } else { - columnWithTable = ((SlotRef) predicate.getChild(0)).getColumnName(); + columnWithTable = ((SlotRef) predicate.getChildWithoutCast(0)).getColumnName(); } slotNameToMergeExprsMap.computeIfAbsent(columnWithTable, key -> { slotNameForMerge.add(columnWithTable); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org