This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 385d9dbe71 try to remove redundant alias in expression rewriter and 
select (#20867)
385d9dbe71 is described below

commit 385d9dbe7106077b8f07922e9f0c2cf6d0d7ce6e
Author: Burak Şen <[email protected]>
AuthorDate: Thu Mar 12 15:48:29 2026 +0300

    try to remove redundant alias in expression rewriter and select (#20867)
    
    ## Which issue does this PR close?
    N/A (this PR does not close an issue).
    
    ## Rationale for this change
    In
    https://github.com/apache/datafusion/pull/20780#discussion_r2911482011,
    @alamb asked whether the redundant alias in `count(*) AS count(*)` could
    be removed, leaving just `count(*)`. This PR attempts that.
    
    
    ### I'm not yet sure about all the implications, so input on this PR
    would be much appreciated
    
    ## What changes are included in this PR?
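    The main changes are in:
    - order_by.rs: match only top-level projection expressions instead of
    recursively searching their subexpressions (otherwise an ORDER BY on
    `min(c2)` could match inside a composite such as `min(c2) + max(c3)`).
    A match is now replaced by a column reference to the projection's
    output name rather than by the matched expression itself, which is what
    removes the redundant `AS` alias from the Sort.
    - select.rs: strip the alias from the rewritten expression before
    comparing it with aliased SELECT expressions; otherwise the existing
    SELECT alias is never used.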
    
    ## Are these changes tested?
    Yes. I've added tests for aliased aggregates in ORDER BY (unit tests in
    order_by.rs and sqllogictests in order.slt). Some existing tests and
    expected plan outputs changed as well; see the diff.
    
    ## Are there any user-facing changes?
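    Logical plans change (e.g. `Sort: count(*) AS count(*) DESC NULLS FIRST`
    becomes `Sort: count(*) DESC NULLS FIRST`), but I'm not sure whether this
    has any further user-facing impact.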
---
 datafusion/core/tests/dataframe/mod.rs            |   2 +-
 datafusion/expr/src/expr_rewriter/order_by.rs     | 240 +++++++++++++++++++---
 datafusion/sql/src/select.rs                      |  17 +-
 datafusion/sqllogictest/test_files/clickbench.slt |   8 +-
 datafusion/sqllogictest/test_files/order.slt      |  48 +++++
 5 files changed, 279 insertions(+), 36 deletions(-)

diff --git a/datafusion/core/tests/dataframe/mod.rs 
b/datafusion/core/tests/dataframe/mod.rs
index b1ee8b09b9..80bbde1f6b 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -3004,7 +3004,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
     
+---------------+------------------------------------------------------------------------------------+
     | plan_type     | plan                                                     
                          |
     
+---------------+------------------------------------------------------------------------------------+
-    | logical_plan  | Sort: count(*) AS count(*) ASC NULLS LAST                
                          |
+    | logical_plan  | Sort: count(*) ASC NULLS LAST                            
                          |
     |               |   Projection: t1.b, count(Int64(1)) AS count(*)          
                          |
     |               |     Aggregate: groupBy=[[t1.b]], 
aggr=[[count(Int64(1))]]                          |
     |               |       TableScan: t1 projection=[b]                       
                          |
diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs 
b/datafusion/expr/src/expr_rewriter/order_by.rs
index a897e56d27..720788113c 100644
--- a/datafusion/expr/src/expr_rewriter/order_by.rs
+++ b/datafusion/expr/src/expr_rewriter/order_by.rs
@@ -21,9 +21,7 @@ use crate::expr::Alias;
 use crate::expr_rewriter::normalize_col;
 use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort};
 
-use datafusion_common::tree_node::{
-    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
-};
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{Column, Result};
 
 /// Rewrite sort on aggregate expressions to sort on the column of aggregate 
output
@@ -104,29 +102,27 @@ fn rewrite_in_terms_of_projection(
 
         let search_col = Expr::Column(Column::new_unqualified(name));
 
-        // look for the column named the same as this expr
-        let mut found = None;
-        for proj_expr in proj_exprs {
-            proj_expr.apply(|e| {
-                if expr_match(&search_col, e) {
-                    found = Some(e.clone());
-                    return Ok(TreeNodeRecursion::Stop);
-                }
-                Ok(TreeNodeRecursion::Continue)
-            })?;
-        }
+        // Search only top-level projection expressions for a match.
+        // We intentionally avoid a recursive search (e.g. `apply`) to
+        // prevent matching sub-expressions of composites like
+        // `min(c2) + max(c3)` when the ORDER BY is just `min(c2)`.
+        let found = proj_exprs
+            .iter()
+            .find(|proj_expr| expr_match(&search_col, proj_expr));
 
         if let Some(found) = found {
+            let (qualifier, field_name) = found.qualified_name();
+            let col = Expr::Column(Column::new(qualifier, field_name));
             return Ok(Transformed::yes(match normalized_expr {
                 Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
-                    expr: Box::new(found),
+                    expr: Box::new(col),
                     field,
                 }),
                 Expr::TryCast(TryCast { expr: _, field }) => 
Expr::TryCast(TryCast {
-                    expr: Box::new(found),
+                    expr: Box::new(col),
                     field,
                 }),
-                _ => found,
+                _ => col,
             }));
         }
 
@@ -160,7 +156,10 @@ mod test {
 
     use super::*;
     use crate::test::function_stub::avg;
+    use crate::test::function_stub::count;
+    use crate::test::function_stub::max;
     use crate::test::function_stub::min;
+    use crate::test::function_stub::sum;
 
     #[test]
     fn rewrite_sort_cols_by_agg() {
@@ -242,17 +241,14 @@ mod test {
             TestCase {
                 desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* 
"min(t.c2)"!)"#,
                 input: sort(col("c1") + min(col("c2"))),
-                // should be "c1" not t.c1
                 expected: sort(
                     col("c1") + 
Expr::Column(Column::new_unqualified("min(t.c2)")),
                 ),
             },
             TestCase {
-                desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* 
"avg(t.c3)", aliased)"#,
+                desc: r#"avg(c3) --> "average" (column *named* "average", from 
alias)"#,
                 input: sort(avg(col("c3"))),
-                expected: sort(
-                    
Expr::Column(Column::new_unqualified("avg(t.c3)")).alias("average"),
-                ),
+                expected: sort(col("average")),
             },
         ];
 
@@ -261,6 +257,202 @@ mod test {
         }
     }
 
+    /// When an aggregate is aliased in the projection,
+    /// ORDER BY on the original aggregate expression should resolve to
+    /// a Column reference using the alias name — not leak the inner
+    /// Alias expression node or resolve to a descendant subtree.
+    #[test]
+    fn rewrite_sort_resolves_alias_to_column_ref() {
+        let plan = make_input()
+            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
+            .unwrap()
+            .project(vec![
+                col("c1"),
+                min(col("c2")).alias("min_val"),
+                max(col("c3")).alias("max_val"),
+            ])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let cases = vec![
+            TestCase {
+                desc: "min(c2) with alias 'min_val' should resolve to 
col(min_val)",
+                input: sort(min(col("c2"))),
+                expected: sort(col("min_val")),
+            },
+            TestCase {
+                desc: "max(c3) with alias 'max_val' should resolve to 
col(max_val)",
+                input: sort(max(col("c3"))),
+                expected: sort(col("max_val")),
+            },
+        ];
+
+        for case in cases {
+            case.run(&plan)
+        }
+    }
+
+    #[test]
+    fn composite_proj_expr_containing_sort_col_as_subexpr() {
+        let plan = make_input()
+            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))])
+            .unwrap()
+            .project(vec![
+                col("c1"),
+                (min(col("c2")) + max(col("c3"))).alias("range"),
+                min(col("c2")).alias("min_val"),
+                max(col("c3")).alias("max_val"),
+            ])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let cases = vec![
+            TestCase {
+                desc: "sort by min(c2) should resolve to col(min_val), not 
col(range)",
+                input: sort(min(col("c2"))),
+                expected: sort(col("min_val")),
+            },
+            TestCase {
+                desc: "sort by max(c3) should resolve to col(max_val), not 
col(range)",
+                input: sort(max(col("c3"))),
+                expected: sort(col("max_val")),
+            },
+        ];
+
+        for case in cases {
+            case.run(&plan)
+        }
+    }
+
+    #[test]
+    fn composite_before_standalone_should_not_shadow() {
+        let plan = make_input()
+            .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c2"))])
+            .unwrap()
+            .project(vec![
+                col("c1"),
+                (min(col("c2")) + max(col("c2"))).alias("combined"),
+                min(col("c2")),
+            ])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let cases = vec![TestCase {
+            desc: "sort by min(c2) should resolve to col(min(t.c2)), not 
col(combined)",
+            input: sort(min(col("c2"))),
+            expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))),
+        }];
+
+        for case in cases {
+            case.run(&plan)
+        }
+    }
+
+    #[test]
+    fn duplicate_aggregate_in_multiple_proj_exprs() {
+        let plan = make_input()
+            .aggregate(vec![col("c1")], vec![min(col("c2"))])
+            .unwrap()
+            .project(vec![
+                col("c1"),
+                min(col("c2")).alias("first_alias"),
+                min(col("c2")).alias("second_alias"),
+            ])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let cases = vec![TestCase {
+            desc: "sort by min(c2) with two aliases picks first_alias",
+            input: sort(min(col("c2"))),
+            expected: sort(col("first_alias")),
+        }];
+
+        for case in cases {
+            case.run(&plan)
+        }
+    }
+
+    #[test]
+    fn sort_agg_not_in_select_with_aliased_aggs() {
+        let plan = make_input()
+            .aggregate(
+                vec![col("c1")],
+                vec![min(col("c2")), max(col("c3")), sum(col("c3"))],
+            )
+            .unwrap()
+            .project(vec![
+                col("c1"),
+                min(col("c2")).alias("min_val"),
+                max(col("c3")).alias("max_val"),
+            ])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let cases = vec![TestCase {
+            desc: "sort by sum(c3) not in projection should not be rewritten",
+            input: sort(sum(col("c3"))),
+            expected: sort(sum(col("c3"))),
+        }];
+
+        for case in cases {
+            case.run(&plan)
+        }
+    }
+
+    #[test]
+    fn cast_on_aliased_aggregate() {
+        let plan = make_input()
+            .aggregate(vec![col("c1")], vec![min(col("c2"))])
+            .unwrap()
+            .project(vec![col("c1"), min(col("c2")).alias("min_val")])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let cases = vec![
+            TestCase {
+                desc: "CAST on aliased aggregate should preserve cast and 
resolve alias",
+                input: sort(cast(min(col("c2")), DataType::Int64)),
+                expected: sort(cast(col("min_val"), DataType::Int64)),
+            },
+            TestCase {
+                desc: "TryCast on aliased aggregate should preserve try_cast 
and resolve alias",
+                input: sort(try_cast(min(col("c2")), DataType::Int64)),
+                expected: sort(try_cast(col("min_val"), DataType::Int64)),
+            },
+        ];
+
+        for case in cases {
+            case.run(&plan)
+        }
+    }
+
+    #[test]
+    fn count_star_with_alias() {
+        let plan = make_input()
+            .aggregate(vec![col("c1")], vec![count(lit(1))])
+            .unwrap()
+            .project(vec![col("c1"), count(lit(1)).alias("cnt")])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let cases = vec![TestCase {
+            desc: "sort by count(1) should resolve to cnt alias",
+            input: sort(count(lit(1))),
+            expected: sort(col("cnt")),
+        }];
+
+        for case in cases {
+            case.run(&plan)
+        }
+    }
+
     #[test]
     fn preserve_cast() {
         let plan = make_input()
@@ -275,12 +467,12 @@ mod test {
             TestCase {
                 desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                 input: sort(cast(col("c2"), DataType::Int64)),
-                expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
+                expected: sort(cast(col("c2"), DataType::Int64)),
             },
             TestCase {
                 desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                 input: sort(try_cast(col("c2"), DataType::Int64)),
-                expected: sort(try_cast(col("c2").alias("c2"), 
DataType::Int64)),
+                expected: sort(try_cast(col("c2"), DataType::Int64)),
             },
         ];
 
diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs
index edf4b9ef79..7e291afa04 100644
--- a/datafusion/sql/src/select.rs
+++ b/datafusion/sql/src/select.rs
@@ -1056,13 +1056,16 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                     .iter()
                     .find_map(|select_expr| {
                         // Only consider aliased expressions
-                        if let Expr::Alias(alias) = select_expr
-                            && alias.expr.as_ref() == &rewritten_expr
-                        {
-                            // Use the alias name
-                            return Some(Expr::Column(Column::new_unqualified(
-                                alias.name.clone(),
-                            )));
+                        if let Expr::Alias(alias) = select_expr {
+                            let rewritten_unaliased = match &rewritten_expr {
+                                Expr::Alias(a) => a.expr.as_ref(),
+                                other => other,
+                            };
+                            if alias.expr.as_ref() == rewritten_unaliased {
+                                return 
Some(Expr::Column(Column::new_unqualified(
+                                    alias.name.clone(),
+                                )));
+                            }
                         }
                         None
                     })
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt 
b/datafusion/sqllogictest/test_files/clickbench.slt
index 881e49cdeb..e14d28d5ef 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -205,7 +205,7 @@ query TT
 EXPLAIN SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 
GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC;
 ----
 logical_plan
-01)Sort: count(*) AS count(*) DESC NULLS FIRST
+01)Sort: count(*) DESC NULLS FIRST
 02)--Projection: hits.AdvEngineID, count(Int64(1)) AS count(*)
 03)----Aggregate: groupBy=[[hits.AdvEngineID]], aggr=[[count(Int64(1))]]
 04)------SubqueryAlias: hits
@@ -431,7 +431,7 @@ query TT
 EXPLAIN SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY 
COUNT(*) DESC LIMIT 10;
 ----
 logical_plan
-01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
+01)Sort: count(*) DESC NULLS FIRST, fetch=10
 02)--Projection: hits.UserID, count(Int64(1)) AS count(*)
 03)----Aggregate: groupBy=[[hits.UserID]], aggr=[[count(Int64(1))]]
 04)------SubqueryAlias: hits
@@ -459,7 +459,7 @@ query TT
 EXPLAIN SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", 
"SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
 ----
 logical_plan
-01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
+01)Sort: count(*) DESC NULLS FIRST, fetch=10
 02)--Projection: hits.UserID, hits.SearchPhrase, count(Int64(1)) AS count(*)
 03)----Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]], 
aggr=[[count(Int64(1))]]
 04)------SubqueryAlias: hits
@@ -514,7 +514,7 @@ query TT
 EXPLAIN SELECT "UserID", extract(minute FROM 
to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits 
GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
 ----
 logical_plan
-01)Sort: count(*) AS count(*) DESC NULLS FIRST, fetch=10
+01)Sort: count(*) DESC NULLS FIRST, fetch=10
 02)--Projection: hits.UserID, 
date_part(Utf8("MINUTE"),to_timestamp_seconds(hits.EventTime)) AS m, 
hits.SearchPhrase, count(Int64(1)) AS count(*)
 03)----Aggregate: groupBy=[[hits.UserID, date_part(Utf8("MINUTE"), 
to_timestamp_seconds(hits.EventTime)), hits.SearchPhrase]], 
aggr=[[count(Int64(1))]]
 04)------SubqueryAlias: hits
diff --git a/datafusion/sqllogictest/test_files/order.slt 
b/datafusion/sqllogictest/test_files/order.slt
index 7c857cae36..892a42ad61 100644
--- a/datafusion/sqllogictest/test_files/order.slt
+++ b/datafusion/sqllogictest/test_files/order.slt
@@ -471,6 +471,54 @@ select column1 from foo order by column2 % 2, column2;
 3
 5
 
+# ORDER BY aggregate expression that is aliased in SELECT
+query II
+select column1, min(column2) as min_val from foo group by column1 order by 
min(column2);
+----
+1 2
+3 4
+5 6
+
+# ORDER BY aggregate with alias, using DESC
+query II rowsort
+select column1, count(*) as cnt from foo group by column1 order by count(*) 
desc;
+----
+1 1
+3 1
+5 1
+
+# ORDER BY aggregate not in SELECT, while other aggregates in SELECT are 
aliased
+query I
+select column1 from foo group by column1 order by max(column2);
+----
+1
+3
+5
+
+# SELECT has composite expression containing the aggregate, plus standalone 
alias
+query III
+select column1, min(column2) + max(column2) as range_val, min(column2) as 
min_val from foo group by column1 order by min(column2);
+----
+1 4 2
+3 8 4
+5 12 6
+
+# ORDER BY aggregate that matches multiple aliased SELECT expressions
+query III
+select column1, min(column2) as first_min, min(column2) as second_min from foo 
group by column1 order by min(column2);
+----
+1 2 2
+3 4 4
+5 6 6
+
+# ORDER BY with CAST on aliased aggregate
+query II
+select column1, min(column2) as min_val from foo group by column1 order by 
CAST(min(column2) AS BIGINT);
+----
+1 2
+3 4
+5 6
+
 # Cleanup
 statement ok
 drop table foo;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to