This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 11ba26c790e [opt](Nereids) aggregate function sum support string type 
as parameter (#49954)
11ba26c790e is described below

commit 11ba26c790e725bde873b7d7af6e9e5996215ad6
Author: morrySnow <zhangwen...@selectdb.com>
AuthorDate: Fri Apr 18 15:14:35 2025 +0800

    [opt](Nereids) aggregate function sum support string type as parameter 
(#49954)
    
    ### What problem does this PR solve?
    
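    Allow the aggregate functions sum, sum0 and avg to accept string-like arguments (CHAR, VARCHAR, STRING) by casting them to DOUBLE; previously such arguments were rejected with an AnalysisException.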
    
    ### Release note
    
    The aggregate functions sum, sum0 and avg now accept string-type parameters.
---
 .../trees/expressions/functions/agg/Avg.java       | 26 +++++++----
 .../trees/expressions/functions/agg/Sum.java       | 29 +++++++-----
 .../trees/expressions/functions/agg/Sum0.java      | 29 +++++++-----
 .../analysis/CheckExpressionLegalityTest.java      |  2 +-
 .../nereids/trees/expressions/GetDataTypeTest.java | 54 ++++++++++++++++++++--
 .../nereids_function_p0/agg_function/agg.groovy    | 28 +++++++++++
 6 files changed, 131 insertions(+), 37 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
index db1d8d7eb7c..e54b57f15ec 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
@@ -33,6 +33,7 @@ import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.NullType;
 import org.apache.doris.nereids.types.SmallIntType;
 import org.apache.doris.nereids.types.TinyIntType;
 import org.apache.doris.qe.ConnectContext;
@@ -49,14 +50,14 @@ public class Avg extends NullableAggregateFunction
         implements UnaryExpression, ExplicitlyCastableSignature, 
ComputePrecision, SupportWindowAnalytic {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE),
             
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+            
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
             
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT),
-            
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD)
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE)
     );
 
     /**
@@ -80,8 +81,9 @@ public class Avg extends NullableAggregateFunction
     @Override
     public void checkLegalityBeforeTypeCoercion() {
         DataType argType = child().getDataType();
-        if (((!argType.isNumericType() && !argType.isNullType()) || 
argType.isOnlyMetricType())) {
-            throw new AnalysisException("avg requires a numeric parameter: " + 
toSql());
+        if (!argType.isNumericType() && !argType.isBooleanType()
+                && !argType.isNullType() && !argType.isStringLikeType()) {
+            throw new AnalysisException("avg requires a numeric, boolean or 
string parameter: " + this.toSql());
         }
     }
 
@@ -153,4 +155,12 @@ public class Avg extends NullableAggregateFunction
     public List<FunctionSignature> getSignatures() {
         return SIGNATURES;
     }
+
+    @Override
+    public FunctionSignature searchSignature(List<FunctionSignature> 
signatures) {
+        if (getArgument(0).getDataType() instanceof NullType) {
+            return 
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE);
+        }
+        return ExplicitlyCastableSignature.super.searchSignature(signatures);
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
index b6a8cd86566..d1c862f5de4 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
@@ -29,11 +29,13 @@ import 
org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.BooleanType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV2Type;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.FloatType;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.NullType;
 import org.apache.doris.nereids.types.SmallIntType;
 import org.apache.doris.nereids.types.TinyIntType;
 
@@ -50,14 +52,15 @@ public class Sum extends NullableAggregateFunction
         RollUpTrait, SupportMultiDistinct {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
-            
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
             
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
+            
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE)
     );
 
     /**
@@ -88,9 +91,9 @@ public class Sum extends NullableAggregateFunction
     @Override
     public void checkLegalityBeforeTypeCoercion() {
         DataType argType = child().getDataType();
-        if ((!argType.isNumericType() && !argType.isBooleanType() && 
!argType.isNullType())
-                || argType.isOnlyMetricType()) {
-            throw new AnalysisException("sum requires a numeric or boolean 
parameter: " + this.toSql());
+        if (!argType.isNumericType() && !argType.isBooleanType()
+                && !argType.isNullType() && !argType.isStringLikeType()) {
+            throw new AnalysisException("sum requires a numeric, boolean or 
string parameter: " + this.toSql());
         }
     }
 
@@ -120,8 +123,10 @@ public class Sum extends NullableAggregateFunction
 
     @Override
     public FunctionSignature searchSignature(List<FunctionSignature> 
signatures) {
-        if (getArgument(0).getDataType() instanceof FloatType) {
-            return 
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE);
+        if (getArgument(0).getDataType() instanceof NullType) {
+            return 
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE);
+        } else if (getArgument(0).getDataType() instanceof DecimalV2Type) {
+            return 
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD);
         }
         return ExplicitlyCastableSignature.super.searchSignature(signatures);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
index 9d220237a69..e02139420d2 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
@@ -33,11 +33,13 @@ import 
org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.BooleanType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV2Type;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.FloatType;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.NullType;
 import org.apache.doris.nereids.types.SmallIntType;
 import org.apache.doris.nereids.types.TinyIntType;
 
@@ -57,14 +59,15 @@ public class Sum0 extends NotNullableAggregateFunction
         SupportWindowAnalytic, RollUpTrait, SupportMultiDistinct {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
-            
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
             
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
+            
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE)
     );
 
     /**
@@ -91,9 +94,9 @@ public class Sum0 extends NotNullableAggregateFunction
     @Override
     public void checkLegalityBeforeTypeCoercion() {
         DataType argType = child().getDataType();
-        if ((!argType.isNumericType() && !argType.isBooleanType() && 
!argType.isNullType())
-                || argType.isOnlyMetricType()) {
-            throw new AnalysisException("sum0 requires a numeric or boolean 
parameter: " + this.toSql());
+        if (!argType.isNumericType() && !argType.isBooleanType()
+                && !argType.isNullType() && !argType.isStringLikeType()) {
+            throw new AnalysisException("sum0 requires a numeric, boolean or 
string parameter: " + this.toSql());
         }
     }
 
@@ -118,8 +121,10 @@ public class Sum0 extends NotNullableAggregateFunction
 
     @Override
     public FunctionSignature searchSignature(List<FunctionSignature> 
signatures) {
-        if (getArgument(0).getDataType() instanceof FloatType) {
-            return 
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE);
+        if (getArgument(0).getDataType() instanceof NullType) {
+            return 
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE);
+        } else if (getArgument(0).getDataType() instanceof DecimalV2Type) {
+            return 
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD);
         }
         return ExplicitlyCastableSignature.super.searchSignature(signatures);
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckExpressionLegalityTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckExpressionLegalityTest.java
index 2b0ae34dc37..34beb21f440 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckExpressionLegalityTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckExpressionLegalityTest.java
@@ -35,7 +35,7 @@ public class CheckExpressionLegalityTest implements 
MemoPatternMatchSupported {
     public void testAvg() {
         ConnectContext connectContext = MemoTestUtils.createConnectContext();
         ExceptionChecker.expectThrowsWithMsg(
-                AnalysisException.class, "avg requires a numeric parameter", 
() -> {
+                AnalysisException.class, "avg requires a numeric", () -> {
                     PlanChecker.from(connectContext)
                             .analyze("select avg(id) from (select to_bitmap(1) 
id) tbl");
                 });
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java
index 05824d32802..e95b0cd4b4d 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java
@@ -17,13 +17,16 @@
 
 package org.apache.doris.nereids.trees.expressions;
 
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
 import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
@@ -35,6 +38,7 @@ import 
org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV2Type;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.LargeIntType;
@@ -57,6 +61,7 @@ public class GetDataTypeTest {
     FloatLiteral floatLiteral = new FloatLiteral(1.0F);
     DoubleLiteral doubleLiteral = new DoubleLiteral(1.0);
     DecimalLiteral decimalLiteral = new DecimalLiteral(BigDecimal.ONE);
+    DecimalV3Literal decimalV3Literal = new DecimalV3Literal(new 
BigDecimal("123.123456"));
     CharLiteral charLiteral = new CharLiteral("hello", 5);
     VarcharLiteral varcharLiteral = new VarcharLiteral("hello", 5);
     StringLiteral stringLiteral = new StringLiteral("hello");
@@ -75,14 +80,55 @@ public class GetDataTypeTest {
         Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(floatLiteral)));
         Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(doubleLiteral)));
         Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 0), 
checkAndGetDataType(new Sum(decimalLiteral)));
-        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum(bigIntLiteral)));
-        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(charLiteral)));
-        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(varcharLiteral)));
-        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(stringLiteral)));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), 
checkAndGetDataType(new Sum(decimalV3Literal)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(charLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(varcharLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(stringLiteral)));
         Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(dateLiteral)));
         Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(dateTimeLiteral)));
     }
 
+    @Test
+    public void testSum0() {
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(nullLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(booleanLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(tinyIntLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(smallIntLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(integerLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(bigIntLiteral)));
+        Assertions.assertEquals(LargeIntType.INSTANCE, checkAndGetDataType(new 
Sum0(largeIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(floatLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(doubleLiteral)));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 0), 
checkAndGetDataType(new Sum0(decimalLiteral)));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), 
checkAndGetDataType(new Sum0(decimalV3Literal)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(charLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(varcharLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(stringLiteral)));
+        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum0(dateLiteral)));
+        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum0(dateTimeLiteral)));
+    }
+
+    @Test
+    public void testAvg() {
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(nullLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(booleanLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(tinyIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(smallIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(integerLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(bigIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(largeIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(floatLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(doubleLiteral)));
+        Assertions.assertEquals(DecimalV2Type.createDecimalV2Type(27, 9), 
checkAndGetDataType(new Avg(decimalLiteral)));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), 
checkAndGetDataType(new Avg(decimalV3Literal)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(bigIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(charLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(varcharLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(stringLiteral)));
+        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Avg(dateLiteral)));
+        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Avg(dateTimeLiteral)));
+    }
+
     private DataType checkAndGetDataType(Expression expression) {
         expression.checkLegalityBeforeTypeCoercion();
         expression.checkLegalityAfterRewrite();
diff --git a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy 
b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
index df4bea4227c..e581626b574 100644
--- a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
+++ b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
@@ -2515,6 +2515,20 @@ suite("nereids_agg_fn") {
        qt_sql_sum_LargeInt_agg_phase_4_notnull '''
                select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, 
TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), sum(klint) from 
fn_test'''
 
+       // sum on string like
+       explain {
+               sql("select sum(kstr) from fn_test;")
+               contains "partial_sum(cast(kstr as DOUBLE"
+       }
+       explain {
+               sql("select sum(kvchrs3) from fn_test;")
+               contains "partial_sum(cast(kvchrs3 as DOUBLE"
+       }
+       explain {
+               sql("select sum(kchrs3) from fn_test;")
+               contains "partial_sum(cast(kchrs3 as DOUBLE"
+       }
+
        qt_sql_sum0_Boolean '''
                select sum0(kbool) from fn_test'''
        qt_sql_sum0_Boolean_gb '''
@@ -2700,6 +2714,20 @@ suite("nereids_agg_fn") {
        qt_sql_sum0_LargeInt_agg_phase_4_notnull '''
                select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, 
TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), sum0(klint) from 
fn_test'''
 
+       // sum on string like
+       explain {
+               sql("select sum0(kstr) from fn_test;")
+               contains "partial_sum0(cast(kstr as DOUBLE"
+       }
+       explain {
+               sql("select sum0(kvchrs3) from fn_test;")
+               contains "partial_sum0(cast(kvchrs3 as DOUBLE"
+       }
+       explain {
+               sql("select sum0(kchrs3) from fn_test;")
+               contains "partial_sum0(cast(kchrs3 as DOUBLE"
+       }
+
        qt_sql_topn_Varchar_Integer_gb '''
                select topn(kvchrs1, 3) from fn_test group by kbool order by 
kbool'''
        qt_sql_topn_Varchar_Integer '''


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to