This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c4c0db027cfe [SPARK-52686][SQL][FOLLOWUP] Don't push `Project` through `Union` if there are duplicates in the project list c4c0db027cfe is described below commit c4c0db027cfeca8043c591a6e6bf9dbe146bc931 Author: Mihailo Timotic <mihailo.timo...@databricks.com> AuthorDate: Thu Jul 24 12:49:54 2025 +0800 [SPARK-52686][SQL][FOLLOWUP] Don't push `Project` through `Union` if there are duplicates in the project list ### What changes were proposed in this pull request? Don't push `Project` through `Union` if there are duplicates in the project list. ### Why are the changes needed? This fixes a change made in https://github.com/apache/spark/pull/51376. Pushing down the `Project` and deduplicating its output with aliases changes the output attribute IDs and may cause issues for downstream operators. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a new test case. ### Was this patch authored or co-authored using generative AI tooling? No Closes #51628 from mihailotim-db/mihailotim-db/push_through_project. 
Authored-by: Mihailo Timotic <mihailo.timo...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 61 +++++++++------------- .../sql/catalyst/optimizer/SetOperationSuite.scala | 22 ++++++++ 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2ae507c831f1..7d49d7cd8732 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import java.util.HashSet - import scala.collection.mutable import org.apache.spark.SparkException @@ -897,24 +895,6 @@ object LimitPushDown extends Rule[LogicalPlan] { */ object PushProjectionThroughUnion extends Rule[LogicalPlan] { - /** - * When pushing a [[Project]] through [[Union]] we need to maintain the invariant that [[Union]] - * children must have unique [[ExprId]]s per branch. We can safely deduplicate [[ExprId]]s - * without updating any references because those [[ExprId]]s will simply remain unused. - * For example, in a `Project(col1#1, col#1)` we will alias the second `col1` and get - * `Project(col1#1, col1 as col1#2)`. We don't need to update any references to `col1#1` we - * aliased because `col1#1` still exists in [[Project]] output. - */ - private def deduplicateProjectList(projectList: Seq[NamedExpression]) = { - val existingExprIds = new HashSet[ExprId] - projectList.map(attr => if (existingExprIds.contains(attr.exprId)) { - Alias(attr, attr.name)() - } else { - existingExprIds.add(attr.exprId) - attr - }) - } - /** * Maps Attributes from the left side to the corresponding Attribute on the right side. 
*/ @@ -942,16 +922,20 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] { result.asInstanceOf[A] } + /** + * If [[SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED]] is true, [[Project]] can + * only be pushed down if there are no duplicate [[ExprId]]s in the project list. + */ + def canPushProjectionThroughUnion(project: Project): Boolean = { + !conf.unionIsResolvedWhenDuplicatesPerChildResolved || + project.outputSet.size == project.projectList.size + } + def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = { - val deduplicatedProjectList = if (conf.unionIsResolvedWhenDuplicatesPerChildResolved) { - deduplicateProjectList(projectList) - } else { - projectList - } - val newFirstChild = Project(deduplicatedProjectList, u.children.head) + val newFirstChild = Project(projectList, u.children.head) val newOtherChildren = u.children.tail.map { child => val rewrites = buildRewrites(u.children.head, child) - Project(deduplicatedProjectList.map(pushToRight(_, rewrites)), child) + Project(projectList.map(pushToRight(_, rewrites)), child) } newFirstChild +: newOtherChildren } @@ -960,8 +944,9 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] { _.containsAllPatterns(UNION, PROJECT)) { // Push down deterministic projection through UNION ALL - case Project(projectList, u: Union) - if projectList.forall(_.deterministic) && u.children.nonEmpty => + case project @ Project(projectList, u: Union) + if projectList.forall(_.deterministic) && u.children.nonEmpty && + canPushProjectionThroughUnion(project) => u.copy(children = pushProjectionThroughUnion(projectList, u)) } } @@ -1586,7 +1571,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] */ object CombineUnions extends Rule[LogicalPlan] { import CollapseProject.{buildCleanedProjectList, canCollapseExpressions} - import PushProjectionThroughUnion.pushProjectionThroughUnion + import PushProjectionThroughUnion.{canPushProjectionThroughUnion, 
pushProjectionThroughUnion} def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning( _.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) { @@ -1631,17 +1616,19 @@ object CombineUnions extends Rule[LogicalPlan] { stack.pushAll(children.reverse) // Push down projection through Union and then push pushed plan to Stack if // there is a Project. - case Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol))) + case project @ Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol))) if projectList.forall(_.deterministic) && children.nonEmpty && - flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol => + flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol && + canPushProjectionThroughUnion(project) => stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse) - case Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union)) + case project @ Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union)) if projectList.forall(_.deterministic) && flattenDistinct && u.byName == topByName && - u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet => + u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet && + canPushProjectionThroughUnion(project) => stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse) - case Project(projectList, u @ Union(children, byName, allowMissingCol)) - if projectList.forall(_.deterministic) && children.nonEmpty && - byName == topByName && allowMissingCol == topAllowMissingCol => + case project @ Project(projectList, u @ Union(children, byName, allowMissingCol)) + if projectList.forall(_.deterministic) && children.nonEmpty && byName == topByName && + allowMissingCol == topAllowMissingCol && canPushProjectionThroughUnion(project) => stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse) case child => flattened += child diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index b2b1f9014989..d3aa1e0cd37c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanO import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, DecimalType} class SetOperationSuite extends PlanTest { @@ -313,6 +314,27 @@ class SetOperationSuite extends PlanTest { comparePlans(unionOptimized, unionCorrectAnswer) } + test("SPARK-52686: no pushdown if project has duplicate expression IDs") { + val unionQuery = testUnion.select($"a", $"a") + val unionCorrectAnswerWithConfOn = unionQuery.analyze + val unionCorrectAnswerWithConfOff = Union( + testRelation.select($"a", $"a").analyze :: + testRelation2.select($"d", $"d").analyze :: + testRelation3.select($"g", $"g").analyze :: + Nil + ) + + withSQLConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> "true") { + val unionOptimized = Optimize.execute(unionQuery.analyze) + comparePlans(unionOptimized, unionCorrectAnswerWithConfOn) + } + + withSQLConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> "false") { + val unionOptimized = Optimize.execute(unionQuery.analyze) + comparePlans(unionOptimized, unionCorrectAnswerWithConfOff) + } + } + test("CombineUnions only flatten the unions with same byName and allowMissingCol") { val union1 = Union(testRelation :: testRelation :: Nil, true, false) val union2 = Union(testRelation :: testRelation :: Nil, true, true) 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org