This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3fab712f69f0 [SPARK-50441][SQL] Fix parametrized identifiers not 
working when referencing CTEs
3fab712f69f0 is described below

commit 3fab712f69f0073d6e5481d43c455363431952fc
Author: Mihailo Timotic <[email protected]>
AuthorDate: Fri Nov 29 21:46:32 2024 +0800

    [SPARK-50441][SQL] Fix parametrized identifiers not working when 
referencing CTEs
    
    ### What changes were proposed in this pull request?
    Fix parameterized identifiers (`IDENTIFIER(:param)` / `IDENTIFIER(?)`) failing to resolve when the bound name refers to a CTE.
    
    ### Why are the changes needed?
    For a query:
    
    `with t1 as (select 1) select * from identifier(:cte)`, with the named parameter `cte` bound to `"t1"`,
    
    resolution deadlocks. `BindParameters` will not bind parameters while the plan still contains
    `UnresolvedWithCTERelations`, because it waits for `ResolveIdentifierClause` to resolve that node.
    However, `ResolveIdentifierClause` cannot resolve `UnresolvedWithCTERelations` until every
    `NamedParameter` in the plan has been bound.
    
    Instead of delaying CTE resolution with `UnresolvedWithCTERelations`, we remove that node
    entirely. The resolution is still delayed: we keep the original `PlanWithUnresolvedIdentifier`
    and move the CTE lookup into its `planBuilder`, so the lookup runs once the identifier is evaluated.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a new test to `ParametersSuite`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48994 from mihailotim-db/mihailotim-db/cte_identifer.
    
    Authored-by: Mihailo Timotic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  3 --
 .../sql/catalyst/analysis/CTESubstitution.scala    | 41 +++++++++++++++-------
 .../analysis/ResolveIdentifierClause.scala         | 15 ++------
 .../spark/sql/catalyst/analysis/parameters.scala   |  6 ++--
 .../spark/sql/catalyst/analysis/unresolved.scala   | 13 +------
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  1 -
 .../org/apache/spark/sql/ParametersSuite.scala     | 11 ++++++
 7 files changed, 47 insertions(+), 43 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3af3565220bd..089e18e3df4e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1610,9 +1610,6 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
       case s: Sort if !s.resolved || s.missingInput.nonEmpty =>
         resolveReferencesInSort(s)
 
-      case u: UnresolvedWithCTERelations =>
-        UnresolvedWithCTERelations(this.apply(u.unresolvedPlan), 
u.cteRelations)
-
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve 
${q.simpleString(conf.maxToStringFields)}")
         q.mapExpressions(resolveExpressionByPlanChildren(_, q, 
includeLastResort = true))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index ff0dbcd7ef15..d75e7d528d5b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -267,6 +267,25 @@ object CTESubstitution extends Rule[LogicalPlan] {
     resolvedCTERelations
   }
 
+  private def resolveWithCTERelations(
+      table: String,
+      alwaysInline: Boolean,
+      cteRelations: Seq[(String, CTERelationDef)],
+      unresolvedRelation: UnresolvedRelation): LogicalPlan = {
+    cteRelations
+      .find(r => conf.resolver(r._1, table))
+      .map {
+        case (_, d) =>
+          if (alwaysInline) {
+            d.child
+          } else {
+            // Add a `SubqueryAlias` for hint-resolving rules to match 
relation names.
+            SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, 
d.isStreaming))
+          }
+      }
+      .getOrElse(unresolvedRelation)
+  }
+
   private def substituteCTE(
       plan: LogicalPlan,
       alwaysInline: Boolean,
@@ -279,22 +298,20 @@ object CTESubstitution extends Rule[LogicalPlan] {
         throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(table))
 
       case u @ UnresolvedRelation(Seq(table), _, _) =>
-        cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, 
d) =>
-          if (alwaysInline) {
-            d.child
-          } else {
-            // Add a `SubqueryAlias` for hint-resolving rules to match 
relation names.
-            SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, 
d.isStreaming))
-          }
-        }.getOrElse(u)
+        resolveWithCTERelations(table, alwaysInline, cteRelations, u)
 
       case p: PlanWithUnresolvedIdentifier =>
         // We must look up CTE relations first when resolving 
`UnresolvedRelation`s,
         // but we can't do it here as `PlanWithUnresolvedIdentifier` is a leaf 
node
-        // and may produce `UnresolvedRelation` later.
-        // Here we wrap it with `UnresolvedWithCTERelations` so that we can
-        // delay the CTE relations lookup after `PlanWithUnresolvedIdentifier` 
is resolved.
-        UnresolvedWithCTERelations(p, cteRelations)
+        // and may produce `UnresolvedRelation` later. Instead, we delay CTE 
resolution
+        // by moving it to the planBuilder of the corresponding 
`PlanWithUnresolvedIdentifier`.
+        p.copy(planBuilder = (nameParts, children) => {
+          p.planBuilder.apply(nameParts, children) match {
+            case u @ UnresolvedRelation(Seq(table), _, _) =>
+              resolveWithCTERelations(table, alwaysInline, cteRelations, u)
+            case other => other
+          }
+        })
 
       case other =>
         // This cannot be done in ResolveSubquery because ResolveSubquery does 
not know the CTE.
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
index 0e1e71a658c8..2cf3c6390d5f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, 
Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, 
LogicalPlan, SubqueryAlias}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_IDENTIFIER, 
UNRESOLVED_IDENTIFIER_WITH_CTE}
+import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
 import org.apache.spark.sql.types.StringType
 
 /**
@@ -30,18 +30,9 @@ import org.apache.spark.sql.types.StringType
 object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with 
EvalHelper {
 
   override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUpWithPruning(
-    _.containsAnyPattern(UNRESOLVED_IDENTIFIER, 
UNRESOLVED_IDENTIFIER_WITH_CTE)) {
+    _.containsPattern(UNRESOLVED_IDENTIFIER)) {
     case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && 
p.childrenResolved =>
       p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children)
-    case u @ UnresolvedWithCTERelations(p, cteRelations) =>
-      this.apply(p) match {
-        case u @ UnresolvedRelation(Seq(table), _, _) =>
-          cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case 
(_, d) =>
-            // Add a `SubqueryAlias` for hint-resolving rules to match 
relation names.
-            SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, 
d.isStreaming))
-          }.getOrElse(u)
-        case other => other
-      }
     case other =>
       
other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER))
 {
         case e: ExpressionWithUnresolvedIdentifier if 
e.identifierExpr.resolved =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
index f24227abbb65..de7374776946 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/parameters.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, 
CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, 
MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, 
VariableReference}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
SupervisingCommand}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, 
PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_IDENTIFIER_WITH_CTE, 
UNRESOLVED_WITH}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, 
PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
 import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.types.DataType
 
@@ -189,7 +189,7 @@ object BindParameters extends ParameterizedQueryProcessor 
with QueryErrorsBase {
       // We should wait for `CTESubstitution` to resolve CTE before binding 
parameters, as CTE
       // relations are not children of `UnresolvedWith`.
       case NameParameterizedQuery(child, argNames, argValues)
-        if !child.containsAnyPattern(UNRESOLVED_WITH, 
UNRESOLVED_IDENTIFIER_WITH_CTE) &&
+        if !child.containsPattern(UNRESOLVED_WITH) &&
           argValues.forall(_.resolved) =>
         if (argNames.length != argValues.length) {
           throw SparkException.internalError(s"The number of argument names 
${argNames.length} " +
@@ -200,7 +200,7 @@ object BindParameters extends ParameterizedQueryProcessor 
with QueryErrorsBase {
         bind(child) { case NamedParameter(name) if args.contains(name) => 
args(name) }
 
       case PosParameterizedQuery(child, args)
-        if !child.containsAnyPattern(UNRESOLVED_WITH, 
UNRESOLVED_IDENTIFIER_WITH_CTE) &&
+        if !child.containsPattern(UNRESOLVED_WITH) &&
           args.forall(_.resolved) =>
         val indexedArgs = args.zipWithIndex
         checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 7fc8aff72b81..0a73b6b85674 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, 
InternalRow, TableIden
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, 
LogicalPlan, UnaryNode}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, 
UnaryNode}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
@@ -76,17 +76,6 @@ case class PlanWithUnresolvedIdentifier(
     copy(identifierExpr, newChildren, planBuilder)
 }
 
-/**
- * A logical plan placeholder which delays CTE resolution
- * to moment when PlanWithUnresolvedIdentifier gets resolved
- */
-case class UnresolvedWithCTERelations(
-   unresolvedPlan: LogicalPlan,
-   cteRelations: Seq[(String, CTERelationDef)])
-  extends UnresolvedLeafNode {
-  final override val nodePatterns: Seq[TreePattern] = 
Seq(UNRESOLVED_IDENTIFIER_WITH_CTE)
-}
-
 /**
  * An expression placeholder that holds the identifier clause string 
expression. It will be
  * replaced by the actual expression with the evaluated identifier string.
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 7435f4c52703..e95712281cb4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -154,7 +154,6 @@ object TreePattern extends Enumeration  {
   val UNRESOLVED_FUNCTION: Value = Value
   val UNRESOLVED_HINT: Value = Value
   val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
-  val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value
 
   // Unresolved Plan patterns (Alphabetically ordered)
   val UNRESOLVED_FUNC: Value = Value
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 791bcc91d509..2ac8ed26868a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -758,4 +758,15 @@ class ParametersSuite extends QueryTest with 
SharedSparkSession with PlanTest {
       checkAnswer(spark.sql(query("?"), args = Array("tt1")), Row(1))
     }
   }
+
+  test("SPARK-50441: parameterized identifier referencing a CTE") {
+    def query(p: String): String = {
+      s"""
+         |WITH t1 AS (SELECT 1)
+         |SELECT * FROM IDENTIFIER($p)""".stripMargin
+    }
+
+    checkAnswer(spark.sql(query(":cte"), args = Map("cte" -> "t1")), Row(1))
+    checkAnswer(spark.sql(query("?"), args = Array("t1")), Row(1))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to