This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c4c0db027cfe [SPARK-52686][SQL][FOLLOWUP] Don't push `Project` through 
`Union` if there are duplicates in the project list
c4c0db027cfe is described below

commit c4c0db027cfeca8043c591a6e6bf9dbe146bc931
Author: Mihailo Timotic <mihailo.timo...@databricks.com>
AuthorDate: Thu Jul 24 12:49:54 2025 +0800

    [SPARK-52686][SQL][FOLLOWUP] Don't push `Project` through `Union` if there 
are duplicates in the project list
    
    ### What changes were proposed in this pull request?
    Don't push `Project` through `Union` if there are duplicates in the project 
list.
    
    ### Why are the changes needed?
    This fixes a change made in https://github.com/apache/spark/pull/51376. 
Pushing down the `Project` and deduplicating its output with aliases changes 
the output attribute IDs and may cause issues for downstream operators.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a new test case.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #51628 from mihailotim-db/mihailotim-db/push_through_project.
    
    Authored-by: Mihailo Timotic <mihailo.timo...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 61 +++++++++-------------
 .../sql/catalyst/optimizer/SetOperationSuite.scala | 22 ++++++++
 2 files changed, 46 insertions(+), 37 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2ae507c831f1..7d49d7cd8732 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import java.util.HashSet
-
 import scala.collection.mutable
 
 import org.apache.spark.SparkException
@@ -897,24 +895,6 @@ object LimitPushDown extends Rule[LogicalPlan] {
  */
 object PushProjectionThroughUnion extends Rule[LogicalPlan] {
 
-  /**
-   * When pushing a [[Project]] through [[Union]] we need to maintain the 
invariant that [[Union]]
-   * children must have unique [[ExprId]]s per branch. We can safely 
deduplicate [[ExprId]]s
-   * without updating any references because those [[ExprId]]s will simply 
remain unused.
-   * For example, in a `Project(col1#1, col#1)` we will alias the second 
`col1` and get
-   * `Project(col1#1, col1 as col1#2)`. We don't need to update any references 
to `col1#1` we
-   * aliased because `col1#1` still exists in [[Project]] output.
-   */
-  private def deduplicateProjectList(projectList: Seq[NamedExpression]) = {
-    val existingExprIds = new HashSet[ExprId]
-    projectList.map(attr => if (existingExprIds.contains(attr.exprId)) {
-      Alias(attr, attr.name)()
-    } else {
-      existingExprIds.add(attr.exprId)
-      attr
-    })
-  }
-
   /**
    * Maps Attributes from the left side to the corresponding Attribute on the 
right side.
    */
@@ -942,16 +922,20 @@ object PushProjectionThroughUnion extends 
Rule[LogicalPlan] {
     result.asInstanceOf[A]
   }
 
+  /**
+   * If [[SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED]] is 
true, [[Project]] can
+   * only be pushed down if there are no duplicate [[ExprId]]s in the project 
list.
+   */
+  def canPushProjectionThroughUnion(project: Project): Boolean = {
+    !conf.unionIsResolvedWhenDuplicatesPerChildResolved ||
+    project.outputSet.size == project.projectList.size
+  }
+
   def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): 
Seq[LogicalPlan] = {
-    val deduplicatedProjectList = if 
(conf.unionIsResolvedWhenDuplicatesPerChildResolved) {
-      deduplicateProjectList(projectList)
-    } else {
-      projectList
-    }
-    val newFirstChild = Project(deduplicatedProjectList, u.children.head)
+    val newFirstChild = Project(projectList, u.children.head)
     val newOtherChildren = u.children.tail.map { child =>
       val rewrites = buildRewrites(u.children.head, child)
-      Project(deduplicatedProjectList.map(pushToRight(_, rewrites)), child)
+      Project(projectList.map(pushToRight(_, rewrites)), child)
     }
     newFirstChild +: newOtherChildren
   }
@@ -960,8 +944,9 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] 
{
     _.containsAllPatterns(UNION, PROJECT)) {
 
     // Push down deterministic projection through UNION ALL
-    case Project(projectList, u: Union)
-        if projectList.forall(_.deterministic) && u.children.nonEmpty =>
+    case project @ Project(projectList, u: Union)
+      if projectList.forall(_.deterministic) && u.children.nonEmpty &&
+        canPushProjectionThroughUnion(project) =>
       u.copy(children = pushProjectionThroughUnion(projectList, u))
   }
 }
@@ -1586,7 +1571,7 @@ object InferFiltersFromConstraints extends 
Rule[LogicalPlan]
  */
 object CombineUnions extends Rule[LogicalPlan] {
   import CollapseProject.{buildCleanedProjectList, canCollapseExpressions}
-  import PushProjectionThroughUnion.pushProjectionThroughUnion
+  import PushProjectionThroughUnion.{canPushProjectionThroughUnion, 
pushProjectionThroughUnion}
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
     _.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
@@ -1631,17 +1616,19 @@ object CombineUnions extends Rule[LogicalPlan] {
           stack.pushAll(children.reverse)
         // Push down projection through Union and then push pushed plan to 
Stack if
         // there is a Project.
-        case Project(projectList, Distinct(u @ Union(children, byName, 
allowMissingCol)))
+        case project @ Project(projectList, Distinct(u @ Union(children, 
byName, allowMissingCol)))
             if projectList.forall(_.deterministic) && children.nonEmpty &&
-              flattenDistinct && byName == topByName && allowMissingCol == 
topAllowMissingCol =>
+              flattenDistinct && byName == topByName && allowMissingCol == 
topAllowMissingCol &&
+              canPushProjectionThroughUnion(project) =>
           stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
-        case Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union))
+        case project @ Project(projectList, Deduplicate(keys: Seq[Attribute], 
u: Union))
             if projectList.forall(_.deterministic) && flattenDistinct && 
u.byName == topByName &&
-              u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == 
u.outputSet =>
+              u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == 
u.outputSet &&
+              canPushProjectionThroughUnion(project) =>
           stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
-        case Project(projectList, u @ Union(children, byName, allowMissingCol))
-            if projectList.forall(_.deterministic) && children.nonEmpty &&
-              byName == topByName && allowMissingCol == topAllowMissingCol =>
+        case project @ Project(projectList, u @ Union(children, byName, 
allowMissingCol))
+            if projectList.forall(_.deterministic) && children.nonEmpty && 
byName == topByName &&
+              allowMissingCol == topAllowMissingCol && 
canPushProjectionThroughUnion(project) =>
           stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
         case child =>
           flattened += child
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index b2b1f9014989..d3aa1e0cd37c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, 
GreaterThan, GreaterThanO
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DecimalType}
 
 class SetOperationSuite extends PlanTest {
@@ -313,6 +314,27 @@ class SetOperationSuite extends PlanTest {
     comparePlans(unionOptimized, unionCorrectAnswer)
   }
 
+  test("SPARK-52686: no pushdown if project has duplicate expression IDs") {
+    val unionQuery = testUnion.select($"a", $"a")
+    val unionCorrectAnswerWithConfOn = unionQuery.analyze
+    val unionCorrectAnswerWithConfOff = Union(
+      testRelation.select($"a", $"a").analyze ::
+      testRelation2.select($"d", $"d").analyze ::
+      testRelation3.select($"g", $"g").analyze ::
+      Nil
+    )
+
+    
withSQLConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> 
"true") {
+      val unionOptimized = Optimize.execute(unionQuery.analyze)
+      comparePlans(unionOptimized, unionCorrectAnswerWithConfOn)
+    }
+
+    
withSQLConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> 
"false") {
+      val unionOptimized = Optimize.execute(unionQuery.analyze)
+      comparePlans(unionOptimized, unionCorrectAnswerWithConfOff)
+    }
+  }
+
   test("CombineUnions only flatten the unions with same byName and 
allowMissingCol") {
     val union1 = Union(testRelation :: testRelation :: Nil, true, false)
     val union2 = Union(testRelation :: testRelation :: Nil, true, true)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to