This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3aa8c9dbc1c0 [SPARK-53155][SQL] Global lower aggregation should not 
be replaced with a project
3aa8c9dbc1c0 is described below

commit 3aa8c9dbc1c0d4622cd62a65db510d6feac31ba3
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Thu Aug 7 08:11:51 2025 -0700

    [SPARK-53155][SQL] Global lower aggregation should not be replaced with a 
project
    
    ### What changes were proposed in this pull request?
    
    This patch fixes the optimization rule `RemoveRedundantAggregates`.
    
    ### Why are the changes needed?
    
    The optimizer rule `RemoveRedundantAggregates` removes redundant lower 
aggregation from a query plan and replaces it with a project of referred 
non-aggregate expressions. However, if the removed aggregation is a global one, 
that is not correct because a project is different from a global aggregation in 
semantics.
    
    For example, if the input relation is empty, a project might be optimized 
to an empty relation, while a global aggregation will return a single row.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this fixes a user-facing bug. Previously, a global aggregation under 
another aggregation might be treated as redundant and replaced with a project 
of non-aggregate expressions. If the input relation is empty, the 
replacement is incorrect and might produce incorrect results. This patch adds a 
new unit test to show the difference.
    
    ### How was this patch tested?
    
    Unit test, manual test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51884 from viirya/fix_remove_redundant_agg.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../optimizer/RemoveRedundantAggregates.scala       |  8 +++++++-
 .../optimizer/RemoveRedundantAggregatesSuite.scala  | 21 ++++++++++++++++++++-
 .../apache/spark/sql/DataFrameAggregateSuite.scala  |  7 +++++++
 3 files changed, 34 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
index d6a4bd030c9d..b4602d0ddcc9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
@@ -54,7 +54,13 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] 
with AliasHelper {
         .map(_.toAttribute)
     ))
 
-    upperHasNoDuplicateSensitiveAgg && upperRefsOnlyDeterministicNonAgg
+    // If the lower aggregation is global, it is not redundant because a 
project with
+    // non-aggregate expressions is different with global aggregation in 
semantics.
+    // E.g., if the input relation is empty, a project might be optimized to 
an empty
+    // relation, while a global aggregation will return a single row.
+    lazy val lowerIsGlobalAgg = lower.groupingExpressions.isEmpty
+
+    upperHasNoDuplicateSensitiveAgg && upperRefsOnlyDeterministicNonAgg && 
!lowerIsGlobalAgg
   }
 
   private def isDuplicateSensitive(ae: AggregateExpression): Boolean = {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
index 2af3057c0b85..40b3d36d4bfc 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDAF}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, 
PythonUDAF}
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LocalRelation, 
LogicalPlan}
@@ -289,4 +289,23 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
     val originalQuery = Distinct(x.groupBy($"a", $"b")($"a", 
TrueLiteral)).analyze
     comparePlans(Optimize.execute(originalQuery), originalQuery)
   }
+
+  test("SPARK-53155: global lower aggregation should not be removed") {
+    object OptimizeNonRemovedRedundantAgg extends RuleExecutor[LogicalPlan] {
+      val batches = Batch("RemoveRedundantAggregates", FixedPoint(10),
+        PropagateEmptyRelation,
+        RemoveRedundantAggregates) :: Nil
+    }
+
+    val query = relation
+      .groupBy()(Literal(1).as("col1"), Literal(2).as("col2"), 
Literal(3).as("col3"))
+      .groupBy($"col1")(max($"col1"))
+      .analyze
+    val expected = relation
+      .groupBy()(Literal(1).as("col1"), Literal(2).as("col2"), 
Literal(3).as("col3"))
+      .groupBy($"col1")(max($"col1"))
+      .analyze
+    val optimized = OptimizeNonRemovedRedundantAgg.execute(query)
+    comparePlans(optimized, expected)
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 656c739af246..721d1c1deea9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -2594,6 +2594,13 @@ class DataFrameAggregateSuite extends QueryTest
       res,
       Row(LocalTime.of(22, 1, 0), LocalTime.of(3, 0, 0)))
   }
+
+  test("SPARK-53155: global lower aggregation should not be removed") {
+    val df = emptyTestData
+      .groupBy().agg(lit(1).as("col1"), lit(2).as("col2"), lit(3).as("col3"))
+      .groupBy($"col1").agg(max("col1"))
+    checkAnswer(df, Seq(Row(1, 1)))
+  }
 }
 
 case class B(c: Option[Double])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to