This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c4c0db027cfe [SPARK-52686][SQL][FOLLOWUP] Don't push `Project` through
`Union` if there are duplicates in the project list
c4c0db027cfe is described below
commit c4c0db027cfeca8043c591a6e6bf9dbe146bc931
Author: Mihailo Timotic <[email protected]>
AuthorDate: Thu Jul 24 12:49:54 2025 +0800
[SPARK-52686][SQL][FOLLOWUP] Don't push `Project` through `Union` if there
are duplicates in the project list
### What changes were proposed in this pull request?
Don't push `Project` through `Union` if there are duplicates in the project
list.
### Why are the changes needed?
This fixes a change made in https://github.com/apache/spark/pull/51376.
Pushing down the `Project` and deduplicating its output with aliases changes
the output attribute ID and may cause issues for the downstream operators.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a new test case.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #51628 from mihailotim-db/mihailotim-db/push_through_project.
Authored-by: Mihailo Timotic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/optimizer/Optimizer.scala | 61 +++++++++-------------
.../sql/catalyst/optimizer/SetOperationSuite.scala | 22 ++++++++
2 files changed, 46 insertions(+), 37 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2ae507c831f1..7d49d7cd8732 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
-import java.util.HashSet
-
import scala.collection.mutable
import org.apache.spark.SparkException
@@ -897,24 +895,6 @@ object LimitPushDown extends Rule[LogicalPlan] {
*/
object PushProjectionThroughUnion extends Rule[LogicalPlan] {
- /**
- * When pushing a [[Project]] through [[Union]] we need to maintain the
invariant that [[Union]]
- * children must have unique [[ExprId]]s per branch. We can safely
deduplicate [[ExprId]]s
- * without updating any references because those [[ExprId]]s will simply
remain unused.
- * For example, in a `Project(col1#1, col#1)` we will alias the second
`col1` and get
- * `Project(col1#1, col1 as col1#2)`. We don't need to update any references
to `col1#1` we
- * aliased because `col1#1` still exists in [[Project]] output.
- */
- private def deduplicateProjectList(projectList: Seq[NamedExpression]) = {
- val existingExprIds = new HashSet[ExprId]
- projectList.map(attr => if (existingExprIds.contains(attr.exprId)) {
- Alias(attr, attr.name)()
- } else {
- existingExprIds.add(attr.exprId)
- attr
- })
- }
-
/**
* Maps Attributes from the left side to the corresponding Attribute on the
right side.
*/
@@ -942,16 +922,20 @@ object PushProjectionThroughUnion extends
Rule[LogicalPlan] {
result.asInstanceOf[A]
}
+ /**
+ * If [[SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED]] is
true, [[Project]] can
+ * only be pushed down if there are no duplicate [[ExprId]]s in the project
list.
+ */
+ def canPushProjectionThroughUnion(project: Project): Boolean = {
+ !conf.unionIsResolvedWhenDuplicatesPerChildResolved ||
+ project.outputSet.size == project.projectList.size
+ }
+
def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union):
Seq[LogicalPlan] = {
- val deduplicatedProjectList = if
(conf.unionIsResolvedWhenDuplicatesPerChildResolved) {
- deduplicateProjectList(projectList)
- } else {
- projectList
- }
- val newFirstChild = Project(deduplicatedProjectList, u.children.head)
+ val newFirstChild = Project(projectList, u.children.head)
val newOtherChildren = u.children.tail.map { child =>
val rewrites = buildRewrites(u.children.head, child)
- Project(deduplicatedProjectList.map(pushToRight(_, rewrites)), child)
+ Project(projectList.map(pushToRight(_, rewrites)), child)
}
newFirstChild +: newOtherChildren
}
@@ -960,8 +944,9 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan]
{
_.containsAllPatterns(UNION, PROJECT)) {
// Push down deterministic projection through UNION ALL
- case Project(projectList, u: Union)
- if projectList.forall(_.deterministic) && u.children.nonEmpty =>
+ case project @ Project(projectList, u: Union)
+ if projectList.forall(_.deterministic) && u.children.nonEmpty &&
+ canPushProjectionThroughUnion(project) =>
u.copy(children = pushProjectionThroughUnion(projectList, u))
}
}
@@ -1586,7 +1571,7 @@ object InferFiltersFromConstraints extends
Rule[LogicalPlan]
*/
object CombineUnions extends Rule[LogicalPlan] {
import CollapseProject.{buildCleanedProjectList, canCollapseExpressions}
- import PushProjectionThroughUnion.pushProjectionThroughUnion
+ import PushProjectionThroughUnion.{canPushProjectionThroughUnion,
pushProjectionThroughUnion}
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
_.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
@@ -1631,17 +1616,19 @@ object CombineUnions extends Rule[LogicalPlan] {
stack.pushAll(children.reverse)
// Push down projection through Union and then push pushed plan to
Stack if
// there is a Project.
- case Project(projectList, Distinct(u @ Union(children, byName,
allowMissingCol)))
+ case project @ Project(projectList, Distinct(u @ Union(children,
byName, allowMissingCol)))
if projectList.forall(_.deterministic) && children.nonEmpty &&
- flattenDistinct && byName == topByName && allowMissingCol ==
topAllowMissingCol =>
+ flattenDistinct && byName == topByName && allowMissingCol ==
topAllowMissingCol &&
+ canPushProjectionThroughUnion(project) =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
- case Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union))
+ case project @ Project(projectList, Deduplicate(keys: Seq[Attribute],
u: Union))
if projectList.forall(_.deterministic) && flattenDistinct &&
u.byName == topByName &&
- u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) ==
u.outputSet =>
+ u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) ==
u.outputSet &&
+ canPushProjectionThroughUnion(project) =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
- case Project(projectList, u @ Union(children, byName, allowMissingCol))
- if projectList.forall(_.deterministic) && children.nonEmpty &&
- byName == topByName && allowMissingCol == topAllowMissingCol =>
+ case project @ Project(projectList, u @ Union(children, byName,
allowMissingCol))
+ if projectList.forall(_.deterministic) && children.nonEmpty &&
byName == topByName &&
+ allowMissingCol == topAllowMissingCol &&
canPushProjectionThroughUnion(project) =>
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
case child =>
flattened += child
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index b2b1f9014989..d3aa1e0cd37c 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{And,
GreaterThan, GreaterThanO
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DecimalType}
class SetOperationSuite extends PlanTest {
@@ -313,6 +314,27 @@ class SetOperationSuite extends PlanTest {
comparePlans(unionOptimized, unionCorrectAnswer)
}
+ test("SPARK-52686: no pushdown if project has duplicate expression IDs") {
+ val unionQuery = testUnion.select($"a", $"a")
+ val unionCorrectAnswerWithConfOn = unionQuery.analyze
+ val unionCorrectAnswerWithConfOff = Union(
+ testRelation.select($"a", $"a").analyze ::
+ testRelation2.select($"d", $"d").analyze ::
+ testRelation3.select($"g", $"g").analyze ::
+ Nil
+ )
+
+
withSQLConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key ->
"true") {
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ comparePlans(unionOptimized, unionCorrectAnswerWithConfOn)
+ }
+
+
withSQLConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key ->
"false") {
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ comparePlans(unionOptimized, unionCorrectAnswerWithConfOff)
+ }
+ }
+
test("CombineUnions only flatten the unions with same byName and
allowMissingCol") {
val union1 = Union(testRelation :: testRelation :: Nil, true, false)
val union2 = Union(testRelation :: testRelation :: Nil, true, true)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]