This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7b499191d28d [SPARK-52427][SQL] Normalize aggregate expression list
covered by a Project
7b499191d28d is described below
commit 7b499191d28de318fe1c0d359fa6a65706e2ee72
Author: Vladimir Golubev <[email protected]>
AuthorDate: Fri Jun 13 16:22:55 2025 -0700
[SPARK-52427][SQL] Normalize aggregate expression list covered by a Project
### What changes were proposed in this pull request?
Normalize the aggregate expression list of an `Aggregate` covered by a `Project`. The order of the aggregate expressions
does not matter in this case, because the `Project` on top determines the output column order.
### Why are the changes needed?
To make sure that plans produced by the single-pass and fixed-point Analyzers are compatible, i.e. compare equal after `NormalizePlan`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests in `NormalizePlanSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #51128 from
vladimirg-db/vladimir-golubev_data/nomalize-aggregate-list-under-project.
Authored-by: Vladimir Golubev <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/plans/NormalizePlan.scala | 48 ++++++++++++---
.../sql/catalyst/plans/NormalizePlanSuite.scala | 69 ++++++++++++++++++----
2 files changed, 98 insertions(+), 19 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
index 643dc8dc8746..41cd5c3dbd86 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
@@ -153,18 +153,40 @@ object NormalizePlan extends PredicateHelper {
.getTagValue(DeduplicateRelations.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION)
.isDefined =>
project.child
+
case aggregate @ Aggregate(_, _, innerProject: Project, _) =>
- val newInnerProject = Project(
- innerProject.projectList.sortBy(_.name),
- innerProject.child
- )
- aggregate.copy(child = newInnerProject)
+ aggregate.copy(child = normalizeProjectListOrder(innerProject))
+
case project @ Project(_, innerProject: Project) =>
- val newInnerProject = Project(
- innerProject.projectList.sortBy(_.name),
- innerProject.child
+ project.copy(child = normalizeProjectListOrder(innerProject))
+
+ case project @ Project(_, innerAggregate: Aggregate) =>
+ project.copy(child = normalizeAggregateListOrder(innerAggregate))
+
+ /**
+ * ORDER BY covered by an output-retaining project on top of GROUP BY
+ */
+ case project @ Project(_, sort @ Sort(_, _, innerAggregate: Aggregate,
_)) =>
+ project.copy(child = sort.copy(child =
normalizeAggregateListOrder(innerAggregate)))
+
+ /**
+ * HAVING covered by an output-retaining project on top of GROUP BY
+ */
+ case project @ Project(_, filter @ Filter(_, innerAggregate: Aggregate))
=>
+ project.copy(child = filter.copy(child =
normalizeAggregateListOrder(innerAggregate)))
+
+ /**
+ * HAVING ... ORDER BY covered by an output-retaining project on top of
GROUP BY
+ */
+ case project @ Project(
+ _,
+ sort @ Sort(_, _, filter @ Filter(_, innerAggregate: Aggregate), _)
+ ) =>
+ project.copy(
+ child =
+ sort.copy(child = filter.copy(child =
normalizeAggregateListOrder(innerAggregate)))
)
- project.copy(child = newInnerProject)
+
case c: KeepAnalyzedQuery => c.storeAnalyzedQuery()
case localRelation: LocalRelation if !localRelation.data.isEmpty =>
/**
@@ -200,6 +222,14 @@ object NormalizePlan extends PredicateHelper {
case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() =>
GreaterThanOrEqual(r, l)
case _ => condition // Don't reorder.
}
+
+ private def normalizeProjectListOrder(project: Project): Project = {
+ project.copy(projectList = project.projectList.sortBy(_.name))
+ }
+
+ private def normalizeAggregateListOrder(aggregate: Aggregate): Aggregate = {
+ aggregate.copy(aggregateExpressions =
aggregate.aggregateExpressions.sortBy(_.name))
+ }
}
class CteIdNormalizer {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala
index 87d59be5aa37..d610fb828f86 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/NormalizePlanSuite.scala
@@ -23,16 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{
- AssertTrue,
- Cast,
- CommonExpressionDef,
- CommonExpressionId,
- CommonExpressionRef,
- If,
- Literal,
- TimeZoneAwareExpression
-}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.BooleanType
@@ -68,6 +59,64 @@ class NormalizePlanSuite extends SparkFunSuite with
SQLConfHelper {
assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
}
+ test("Normalize ordering in an aggregate list of an inner Aggregate under
Project") {
+ val baselinePlan = LocalRelation($"col1".int, $"col2".string)
+ .groupBy($"col1", $"col2")($"col1", $"col2")
+ .select($"col1")
+ val testPlan = LocalRelation($"col1".int, $"col2".string)
+ .groupBy($"col1", $"col2")($"col2", $"col1")
+ .select($"col1")
+
+ assert(baselinePlan != testPlan)
+ assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
+ }
+
+ test("Normalize ordering in an aggregate list of an inner Aggregate under
Project and Filter") {
+ val baselinePlan = LocalRelation($"col1".int, $"col2".string)
+ .groupBy($"col1", $"col2")($"col1", $"col2")
+ .where($"col1" === 1)
+ .select($"col1")
+ val testPlan = LocalRelation($"col1".int, $"col2".string)
+ .groupBy($"col1", $"col2")($"col2", $"col1")
+ .where($"col1" === 1)
+ .select($"col1")
+
+ assert(baselinePlan != testPlan)
+ assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
+ }
+
+ test("Normalize ordering in an aggregate list of an inner Aggregate under
Project and Sort") {
+ val baselinePlan = LocalRelation($"col1".int, $"col2".string)
+ .groupBy($"col1", $"col2")($"col1", $"col2")
+ .orderBy(SortOrder($"col1", Ascending))
+ .select($"col1")
+ val testPlan = LocalRelation($"col1".int, $"col2".string)
+ .groupBy($"col1", $"col2")($"col2", $"col1")
+ .orderBy(SortOrder($"col1", Ascending))
+ .select($"col1")
+
+ assert(baselinePlan != testPlan)
+ assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
+ }
+
+ test(
+ "Normalize ordering in an aggregate list of an inner Aggregate under
Project Sort and Filter"
+ ) {
+ val baselinePlan = LocalRelation($"col1".int, $"col2".string)
+ .groupBy($"col1", $"col2")($"col1", $"col2")
+ .where($"col1" === 1)
+ .orderBy(SortOrder($"col1", Ascending))
+ .select($"col1")
+ val testPlan = LocalRelation($"col1".int, $"col2".string)
+ .groupBy($"col1", $"col2")($"col2", $"col1")
+ .where($"col1" === 1)
+ .orderBy(SortOrder($"col1", Ascending))
+ .select($"col1")
+
+ assert(baselinePlan != testPlan)
+ assert(NormalizePlan(baselinePlan) == NormalizePlan(testPlan))
+ }
+
test("Normalize InheritAnalysisRules expressions") {
val castWithoutTimezone =
Cast(child = Literal(1), dataType = BooleanType, ansiEnabled =
conf.ansiEnabled)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]