This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d65ee811b0f3 [SPARK-46741][SQL] Cache Table with CTE should work when 
CTE in plan expression subquery
d65ee811b0f3 is described below

commit d65ee811b0f3616ecbd49ce77c93217a4d47ca82
Author: Angerszhuuuu <[email protected]>
AuthorDate: Tue Dec 23 13:18:59 2025 +0800

    [SPARK-46741][SQL] Cache Table with CTE should work when CTE in plan 
expression subquery
    
    ### What changes were proposed in this pull request?
    Following the review comment at
https://github.com/apache/spark/pull/53333#discussion_r2629958838
    
    ### Why are the changes needed?
    Support all cases
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #53526 from AngersZhuuuu/SPARK-46741-FOLLOWUP.
    
    Lead-authored-by: Angerszhuuuu <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/normalizer/NormalizeCTEIds.scala  | 29 ++++++++++++++++------
 .../org/apache/spark/sql/CachedTableSuite.scala    | 27 ++++++++++++++++++++
 2 files changed, 48 insertions(+), 8 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
index 1b1b526e7814..6c0bca0e1104 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -17,29 +17,42 @@
 
 package org.apache.spark.sql.catalyst.normalizer
 
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, 
CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
 import org.apache.spark.sql.catalyst.rules.Rule
 
-object NormalizeCTEIds extends Rule[LogicalPlan]{
+object NormalizeCTEIds extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     val curId = new java.util.concurrent.atomic.AtomicLong()
-    plan transformDown {
+    val cteIdToNewId = mutable.Map.empty[Long, Long]
+    applyInternal(plan, curId, cteIdToNewId)
+  }
 
+  private def applyInternal(
+      plan: LogicalPlan,
+      curId: AtomicLong,
+      cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
+    plan transformDownWithSubqueries {
       case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
-        ctas.copy(plan = apply(plan))
+        ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))
 
       case withCTE @ WithCTE(plan, cteDefs) =>
-        val defIdToNewId = withCTE.cteDefs.map(_.id).map((_, 
curId.getAndIncrement())).toMap
-        val normalizedPlan = canonicalizeCTE(plan, defIdToNewId)
         val newCteDefs = cteDefs.map { cteDef =>
-          val normalizedCteDef = canonicalizeCTE(cteDef.child, defIdToNewId)
-          cteDef.copy(child = normalizedCteDef, id = defIdToNewId(cteDef.id))
+          cteIdToNewId.getOrElseUpdate(cteDef.id, curId.getAndIncrement())
+          val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
+          cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
         }
+        val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
         withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
     }
   }
 
-  def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): 
LogicalPlan = {
+  private def canonicalizeCTE(
+      plan: LogicalPlan,
+      defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
     plan.transformDownWithSubqueries {
       // For nested WithCTE, if defIndex didn't contain the cteId,
       // means it's not current WithCTE's ref.
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 880d8d72c73e..0d807aeae4d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -2633,6 +2633,33 @@ class CachedTableSuite extends QueryTest with 
SQLTestUtils
       }
       assert(inMemoryTableScan.size == 1)
       checkAnswer(df, Row(5) :: Nil)
+
+      sql(
+        """
+          |CACHE TABLE cache_subquery_cte_table
+          |WITH v AS (
+          |  SELECT c1 * c2 c3 from t1
+          |)
+          |SELECT *
+          |FROM v
+          |WHERE EXISTS (
+          |  WITH cte AS (SELECT 1 AS id)
+          |  SELECT 1
+          |  FROM cte
+          |  WHERE cte.id = v.c3
+          |)
+          |""".stripMargin)
+
+      val cteInSubquery = sql(
+        """
+          |SELECT * FROM cache_subquery_cte_table
+          |""".stripMargin)
+
+      val subqueryInMemoryTableScan = 
collect(cteInSubquery.queryExecution.executedPlan) {
+        case i: InMemoryTableScanExec => i
+      }
+      assert(subqueryInMemoryTableScan.size == 1)
+      checkAnswer(cteInSubquery, Row(1) :: Nil)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to