This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 987f7552065635a202f0cb49c5a4d83bd235bc4e
Author: feiniaofeiafei <[email protected]>
AuthorDate: Thu Apr 25 15:01:55 2024 +0800

    [Fix](nereids) fix rule SimplifyWindowExpression (#34099)
    
    Co-authored-by: feiniaofeiafei <[email protected]>
---
 .../rules/rewrite/SimplifyWindowExpression.java    | 12 +++-
 .../simplify_window_expression.out                 | 67 +++++++++++++---------
 .../simplify_window_expression.groovy              |  3 +
 3 files changed, 53 insertions(+), 29 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyWindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyWindowExpression.java
index 872ca789818..c0548a42579 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyWindowExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyWindowExpression.java
@@ -27,10 +27,12 @@ import 
org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.WindowExpression;
 import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
@@ -87,11 +89,13 @@ public class SimplifyWindowExpression extends 
OneRewriteRuleFactory {
             if (function instanceof BoundFunction) {
                 BoundFunction boundFunction = (BoundFunction) function;
                 String name = ((BoundFunction) function).getName();
-                if ((name.equals(COUNT) && 
boundFunction.child(0).notNullable())
+                if ((name.equals(COUNT) && checkCount((Count) boundFunction))
                         || REWRRITE_TO_CONST_WINDOW_FUNCTIONS.contains(name)) {
                     projectionsBuilder.add(new Alias(alias.getExprId(), new 
TinyIntLiteral((byte) 1), alias.getName()));
                 } else if (REWRRITE_TO_SLOT_WINDOW_FUNCTIONS.contains(name)) {
-                    projectionsBuilder.add(new Alias(alias.getExprId(), 
boundFunction.child(0), alias.getName()));
+                    projectionsBuilder.add(new Alias(alias.getExprId(),
+                            
TypeCoercionUtils.castIfNotSameType(boundFunction.child(0), 
boundFunction.getDataType()),
+                            alias.getName()));
                 } else {
                     remainWindowExpression.add(expr);
                 }
@@ -120,4 +124,8 @@ public class SimplifyWindowExpression extends 
OneRewriteRuleFactory {
                     window.child(0)));
         }
     }
+
+    private boolean checkCount(Count count) {
+        return count.isCountStar() || count.child(0).notNullable();
+    }
 }
diff --git a/regression-test/data/nereids_rules_p0/simplify_window_expression/simplify_window_expression.out b/regression-test/data/nereids_rules_p0/simplify_window_expression/simplify_window_expression.out
index 3befc3dcbb2..e660cd7702c 100644
--- a/regression-test/data/nereids_rules_p0/simplify_window_expression/simplify_window_expression.out
+++ b/regression-test/data/nereids_rules_p0/simplify_window_expression/simplify_window_expression.out
@@ -119,28 +119,28 @@
 -- !select_avg --
 \N     \N      \N
 \N     \N      \N
-1      1       1
-1      1       1
-2      2       2
-3      3       3
-3      3       3
-4      4       4
-5      5       5
-5      5       5
-7      7       7
+1      1.0     1.0
+1      1.0     1.0
+2      2.0     2.0
+3      3.0     3.0
+3      3.0     3.0
+4      4.0     4.0
+5      5.0     5.0
+5      5.0     5.0
+7      7.0     7.0
 
 -- !more_than_pk --
 \N     \N      \N
 \N     \N      \N
-1      1       1
-1      1       1
-2      2       2
-3      3       3
-3      3       3
-4      4       4
-5      5       5
-5      5       5
-7      7       7
+1      1.0     1.0
+1      1.0     1.0
+2      2.0     2.0
+3      3.0     3.0
+3      3.0     3.0
+4      4.0     4.0
+5      5.0     5.0
+5      5.0     5.0
+7      7.0     7.0
 
 -- !select_last_value_shape --
 PhysicalResultSink
@@ -163,18 +163,31 @@ PhysicalResultSink
 ------filter((mal_test_simplify_window.__DORIS_DELETE_SIGN__ = 0))
 --------PhysicalOlapScan[mal_test_simplify_window]
 
+-- !select_count_star_col1 --
+\N     1       1
+1      1       1
+1      1       1
+2      1       1
+2      1       1
+2      1       1
+3      1       1
+3      1       1
+4      1       1
+6      1       1
+6      1       1
+
 -- !select_upper_plan_use_all_rewrite --
 \N     \N
 \N     \N
-1      1
-1      1
-2      2
-3      3
-3      3
-4      4
-5      5
-5      5
-7      7
+1      1.0
+1      1.0
+2      2.0
+3      3.0
+3      3.0
+4      4.0
+5      5.0
+5      5.0
+7      7.0
 
 -- !select_upper_plan_use_rewrite_and_not_rewrite --
 \N     \N      \N
diff --git a/regression-test/suites/nereids_rules_p0/simplify_window_expression/simplify_window_expression.groovy b/regression-test/suites/nereids_rules_p0/simplify_window_expression/simplify_window_expression.groovy
index 11ad672c74f..3e247b2a78f 100644
--- a/regression-test/suites/nereids_rules_p0/simplify_window_expression/simplify_window_expression.groovy
+++ b/regression-test/suites/nereids_rules_p0/simplify_window_expression/simplify_window_expression.groovy
@@ -78,6 +78,9 @@ suite("simplify_window_expression") {
         explain shape plan
         select b, avg(b) over (partition by a,b,c) c1, avg(b) over (partition 
by a,b,c order by b) c2
         from mal_test_simplify_window"""
+    qt_select_count_star_col1 """
+        select a,count() over (partition by a,b) c1, count() over (partition 
by a,b order by a) c2
+        from mal_test_simplify_window order by 1,2,3;"""
 
     qt_select_upper_plan_use_all_rewrite """
         select b, c1 from (select b,avg(b) over (partition by a,b) c1 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to