This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new e29aeb5373e branch-3.1: [opt](Nereids) aggregate function sum support string type as parameter #49954 (#51963)
e29aeb5373e is described below

commit e29aeb5373e7897d5b0dd194e307e38ee641b1e9
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Fri Jun 20 10:57:37 2025 +0800

    branch-3.1: [opt](Nereids) aggregate function sum support string type as parameter #49954 (#51963)
    
    Cherry-picked from #49954
    
    Co-authored-by: morrySnow <[email protected]>
---
 .../trees/expressions/functions/agg/Avg.java       | 26 +++++++----
 .../trees/expressions/functions/agg/Sum.java       | 29 +++++++-----
 .../trees/expressions/functions/agg/Sum0.java      | 29 +++++++-----
 .../analysis/CheckExpressionLegalityTest.java      |  2 +-
 .../nereids/trees/expressions/GetDataTypeTest.java | 54 ++++++++++++++++++++--
 .../nereids_function_p0/agg_function/agg.groovy    | 28 +++++++++++
 6 files changed, 131 insertions(+), 37 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
index db1d8d7eb7c..e54b57f15ec 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Avg.java
@@ -33,6 +33,7 @@ import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.NullType;
 import org.apache.doris.nereids.types.SmallIntType;
 import org.apache.doris.nereids.types.TinyIntType;
 import org.apache.doris.qe.ConnectContext;
@@ -49,14 +50,14 @@ public class Avg extends NullableAggregateFunction
         implements UnaryExpression, ExplicitlyCastableSignature, 
ComputePrecision, SupportWindowAnalytic {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE),
             
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+            
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
             
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT),
-            
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD)
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE)
     );
 
     /**
@@ -80,8 +81,9 @@ public class Avg extends NullableAggregateFunction
     @Override
     public void checkLegalityBeforeTypeCoercion() {
         DataType argType = child().getDataType();
-        if (((!argType.isNumericType() && !argType.isNullType()) || 
argType.isOnlyMetricType())) {
-            throw new AnalysisException("avg requires a numeric parameter: " + 
toSql());
+        if (!argType.isNumericType() && !argType.isBooleanType()
+                && !argType.isNullType() && !argType.isStringLikeType()) {
+            throw new AnalysisException("avg requires a numeric, boolean or 
string parameter: " + this.toSql());
         }
     }
 
@@ -153,4 +155,12 @@ public class Avg extends NullableAggregateFunction
     public List<FunctionSignature> getSignatures() {
         return SIGNATURES;
     }
+
+    @Override
+    public FunctionSignature searchSignature(List<FunctionSignature> 
signatures) {
+        if (getArgument(0).getDataType() instanceof NullType) {
+            return 
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE);
+        }
+        return ExplicitlyCastableSignature.super.searchSignature(signatures);
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
index e55f926ae4d..b5616dad15c 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java
@@ -29,11 +29,13 @@ import 
org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.BooleanType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV2Type;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.FloatType;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.NullType;
 import org.apache.doris.nereids.types.SmallIntType;
 import org.apache.doris.nereids.types.TinyIntType;
 
@@ -50,14 +52,15 @@ public class Sum extends NullableAggregateFunction
         RollUpTrait {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
-            
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
             
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
+            
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE)
     );
 
     /**
@@ -87,9 +90,9 @@ public class Sum extends NullableAggregateFunction
     @Override
     public void checkLegalityBeforeTypeCoercion() {
         DataType argType = child().getDataType();
-        if ((!argType.isNumericType() && !argType.isBooleanType() && 
!argType.isNullType())
-                || argType.isOnlyMetricType()) {
-            throw new AnalysisException("sum requires a numeric or boolean 
parameter: " + this.toSql());
+        if (!argType.isNumericType() && !argType.isBooleanType()
+                && !argType.isNullType() && !argType.isStringLikeType()) {
+            throw new AnalysisException("sum requires a numeric, boolean or 
string parameter: " + this.toSql());
         }
     }
 
@@ -119,8 +122,10 @@ public class Sum extends NullableAggregateFunction
 
     @Override
     public FunctionSignature searchSignature(List<FunctionSignature> 
signatures) {
-        if (getArgument(0).getDataType() instanceof FloatType) {
-            return 
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE);
+        if (getArgument(0).getDataType() instanceof NullType) {
+            return 
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE);
+        } else if (getArgument(0).getDataType() instanceof DecimalV2Type) {
+            return 
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD);
         }
         return ExplicitlyCastableSignature.super.searchSignature(signatures);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
index 5a1f0f9fb93..7c3873de01f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java
@@ -33,11 +33,13 @@ import 
org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.BooleanType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV2Type;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.FloatType;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.NullType;
 import org.apache.doris.nereids.types.SmallIntType;
 import org.apache.doris.nereids.types.TinyIntType;
 
@@ -57,14 +59,15 @@ public class Sum0 extends NotNullableAggregateFunction
         SupportWindowAnalytic, RollUpTrait {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
-            
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
-            
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
+            
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE),
             
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
-            
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
+            
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
+            
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE)
     );
 
     /**
@@ -90,9 +93,9 @@ public class Sum0 extends NotNullableAggregateFunction
     @Override
     public void checkLegalityBeforeTypeCoercion() {
         DataType argType = child().getDataType();
-        if ((!argType.isNumericType() && !argType.isBooleanType() && 
!argType.isNullType())
-                || argType.isOnlyMetricType()) {
-            throw new AnalysisException("sum0 requires a numeric or boolean 
parameter: " + this.toSql());
+        if (!argType.isNumericType() && !argType.isBooleanType()
+                && !argType.isNullType() && !argType.isStringLikeType()) {
+            throw new AnalysisException("sum0 requires a numeric, boolean or 
string parameter: " + this.toSql());
         }
     }
 
@@ -117,8 +120,10 @@ public class Sum0 extends NotNullableAggregateFunction
 
     @Override
     public FunctionSignature searchSignature(List<FunctionSignature> 
signatures) {
-        if (getArgument(0).getDataType() instanceof FloatType) {
-            return 
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE);
+        if (getArgument(0).getDataType() instanceof NullType) {
+            return 
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE);
+        } else if (getArgument(0).getDataType() instanceof DecimalV2Type) {
+            return 
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD);
         }
         return ExplicitlyCastableSignature.super.searchSignature(signatures);
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckExpressionLegalityTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckExpressionLegalityTest.java
index 2b0ae34dc37..34beb21f440 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckExpressionLegalityTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckExpressionLegalityTest.java
@@ -35,7 +35,7 @@ public class CheckExpressionLegalityTest implements 
MemoPatternMatchSupported {
     public void testAvg() {
         ConnectContext connectContext = MemoTestUtils.createConnectContext();
         ExceptionChecker.expectThrowsWithMsg(
-                AnalysisException.class, "avg requires a numeric parameter", 
() -> {
+                AnalysisException.class, "avg requires a numeric", () -> {
                     PlanChecker.from(connectContext)
                             .analyze("select avg(id) from (select to_bitmap(1) 
id) tbl");
                 });
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java
index 05824d32802..e95b0cd4b4d 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java
@@ -17,13 +17,16 @@
 
 package org.apache.doris.nereids.trees.expressions;
 
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
 import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
@@ -35,6 +38,7 @@ import 
org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV2Type;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.LargeIntType;
@@ -57,6 +61,7 @@ public class GetDataTypeTest {
     FloatLiteral floatLiteral = new FloatLiteral(1.0F);
     DoubleLiteral doubleLiteral = new DoubleLiteral(1.0);
     DecimalLiteral decimalLiteral = new DecimalLiteral(BigDecimal.ONE);
+    DecimalV3Literal decimalV3Literal = new DecimalV3Literal(new 
BigDecimal("123.123456"));
     CharLiteral charLiteral = new CharLiteral("hello", 5);
     VarcharLiteral varcharLiteral = new VarcharLiteral("hello", 5);
     StringLiteral stringLiteral = new StringLiteral("hello");
@@ -75,14 +80,55 @@ public class GetDataTypeTest {
         Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(floatLiteral)));
         Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(doubleLiteral)));
         Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 0), 
checkAndGetDataType(new Sum(decimalLiteral)));
-        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum(bigIntLiteral)));
-        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(charLiteral)));
-        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(varcharLiteral)));
-        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(stringLiteral)));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), 
checkAndGetDataType(new Sum(decimalV3Literal)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(charLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(varcharLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum(stringLiteral)));
         Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(dateLiteral)));
         Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum(dateTimeLiteral)));
     }
 
+    @Test
+    public void testSum0() {
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(nullLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(booleanLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(tinyIntLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(smallIntLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(integerLiteral)));
+        Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new 
Sum0(bigIntLiteral)));
+        Assertions.assertEquals(LargeIntType.INSTANCE, checkAndGetDataType(new 
Sum0(largeIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(floatLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(doubleLiteral)));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 0), 
checkAndGetDataType(new Sum0(decimalLiteral)));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), 
checkAndGetDataType(new Sum0(decimalV3Literal)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(charLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(varcharLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Sum0(stringLiteral)));
+        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum0(dateLiteral)));
+        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Sum0(dateTimeLiteral)));
+    }
+
+    @Test
+    public void testAvg() {
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(nullLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(booleanLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(tinyIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(smallIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(integerLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(bigIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(largeIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(floatLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(doubleLiteral)));
+        Assertions.assertEquals(DecimalV2Type.createDecimalV2Type(27, 9), 
checkAndGetDataType(new Avg(decimalLiteral)));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 6), 
checkAndGetDataType(new Avg(decimalV3Literal)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(bigIntLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(charLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(varcharLiteral)));
+        Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new 
Avg(stringLiteral)));
+        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Avg(dateLiteral)));
+        Assertions.assertThrows(RuntimeException.class, () -> 
checkAndGetDataType(new Avg(dateTimeLiteral)));
+    }
+
     private DataType checkAndGetDataType(Expression expression) {
         expression.checkLegalityBeforeTypeCoercion();
         expression.checkLegalityAfterRewrite();
diff --git a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy 
b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
index ba4ee172517..3e10240ea61 100644
--- a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
+++ b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
@@ -2338,6 +2338,20 @@ suite("nereids_agg_fn") {
        qt_sql_sum_LargeInt_agg_phase_4_notnull '''
                select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, 
TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), sum(klint) from 
fn_test'''
 
+       // sum on string like
+       explain {
+               sql("select sum(kstr) from fn_test;")
+               contains "partial_sum(cast(kstr as DOUBLE"
+       }
+       explain {
+               sql("select sum(kvchrs3) from fn_test;")
+               contains "partial_sum(cast(kvchrs3 as DOUBLE"
+       }
+       explain {
+               sql("select sum(kchrs3) from fn_test;")
+               contains "partial_sum(cast(kchrs3 as DOUBLE"
+       }
+
        qt_sql_sum0_Boolean '''
                select sum0(kbool) from fn_test'''
        qt_sql_sum0_Boolean_gb '''
@@ -2523,6 +2537,20 @@ suite("nereids_agg_fn") {
        qt_sql_sum0_LargeInt_agg_phase_4_notnull '''
                select 
/*+SET_VAR(disable_nereids_rules='THREE_PHASE_AGGREGATE_WITH_DISTINCT, 
TWO_PHASE_AGGREGATE_WITH_DISTINCT')*/ count(distinct id), sum0(klint) from 
fn_test'''
 
+       // sum on string like
+       explain {
+               sql("select sum0(kstr) from fn_test;")
+               contains "partial_sum0(cast(kstr as DOUBLE"
+       }
+       explain {
+               sql("select sum0(kvchrs3) from fn_test;")
+               contains "partial_sum0(cast(kvchrs3 as DOUBLE"
+       }
+       explain {
+               sql("select sum0(kchrs3) from fn_test;")
+               contains "partial_sum0(cast(kchrs3 as DOUBLE"
+       }
+
        qt_sql_topn_Varchar_Integer_gb '''
                select topn(kvchrs1, 3) from fn_test group by kbool order by 
kbool'''
        qt_sql_topn_Varchar_Integer '''


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org

Reply via email to