This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8f696791507d [SPARK-46741][SQL] Cache Table with CTE won't work
8f696791507d is described below
commit 8f696791507d85a684d2e12f579c81e454b87234
Author: Angerszhuuuu <[email protected]>
AuthorDate: Thu Dec 18 15:48:32 2025 +0800
[SPARK-46741][SQL] Cache Table with CTE won't work
### What changes were proposed in this pull request?
Reopen https://github.com/apache/spark/pull/44767
CACHE TABLE with a CTE does not work, for two reasons:
1. In the current code, the CTE in CacheTableAsSelect is inlined.
2. CTERelationRef and CTERelationDef keep their CTE ids through doCanonicalize. A re-analyzed
query gets new CTE ids, so it can never match the cached plan.
### Why are the changes needed?
Bug fix: a table cached from a query that uses a CTE is never used by later queries.
### Does this PR introduce _any_ user-facing change?
Yes. CACHE TABLE with a CTE works after this PR.
Take the final query added to `cache.sql`:
`EXPLAIN EXTENDED SELECT * FROM cache_nested_cte_table;`
Before this PR, the plan is as shown below, and the cache is not used.
<img width="1067" height="584" alt="截屏2025-12-05 11 22 05"
src="https://github.com/user-attachments/assets/045df794-38e2-47d9-848e-cfc3c7525671"
/>
After this PR, the cache is used:
<img width="1279" height="824" alt="截屏2025-12-05 11 32 38"
src="https://github.com/user-attachments/assets/86f5ab33-67c6-44d0-b5d8-4bec51a2d5b7"
/>
### How was this patch tested?
Added a unit test to CachedTableSuite.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53333 from AngersZhuuuu/SPARK-46741.
Authored-by: Angerszhuuuu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/normalizer/NormalizeCTEIds.scala | 54 ++++++++++++++++++++++
.../sql/catalyst/plans/logical/v2Commands.scala | 7 ++-
.../sql/internal/BaseSessionStateBuilder.scala | 2 +
.../org/apache/spark/sql/CachedTableSuite.scala | 38 +++++++++++++++
4 files changed, 100 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
new file mode 100644
index 000000000000..1b1b526e7814
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.normalizer
+
+import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect,
CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+object NormalizeCTEIds extends Rule[LogicalPlan] {
+  // Renumbers CTE ids from 0 in traversal order, so re-analyzed queries can match cached plans.
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    val curId = new java.util.concurrent.atomic.AtomicLong()
+    // Also visit subquery expressions, so WithCTE nodes inside subqueries are renumbered too.
+    plan.transformDownWithSubqueries {
+      // The cached query is renumbered in its own scope, with a fresh counter.
+      case ctas @ CacheTableAsSelect(_, query, _, _, _, _, _) =>
+        ctas.copy(plan = apply(query))
+      case withCTE @ WithCTE(mainPlan, cteDefs) =>
+        val defIdToNewId = cteDefs.map(_.id).map((_, curId.getAndIncrement())).toMap
+        val newCteDefs = cteDefs.map { cteDef =>
+          val normalizedChild = canonicalizeCTE(cteDef.child, defIdToNewId)
+          cteDef.copy(child = normalizedChild, id = defIdToNewId(cteDef.id))
+        }
+        withCTE.copy(plan = canonicalizeCTE(mainPlan, defIdToNewId), cteDefs = newCteDefs)
+    }
+  }
+
+  // Rewrites the ids in `defIdToNewId` on CTE refs and UnionLoop(Ref) nodes, incl. subqueries.
+  def canonicalizeCTE(plan: LogicalPlan, defIdToNewId: Map[Long, Long]): LogicalPlan = {
+    plan.transformDownWithSubqueries {
+      // Ids absent from the map are not defined by the current WithCTE (e.g. a nested one).
+      case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
+        ref.copy(cteId = defIdToNewId(ref.cteId))
+      case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) =>
+        unionLoop.copy(id = defIdToNewId(unionLoop.id))
+      case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) =>
+        unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId))
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 72274ee9bf17..fab64d771093 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -1742,7 +1742,8 @@ case class CacheTableAsSelect(
isLazy: Boolean,
options: Map[String, String],
isAnalyzed: Boolean = false,
- referredTempFunctions: Seq[String] = Seq.empty) extends
AnalysisOnlyCommand {
+ referredTempFunctions: Seq[String] = Seq.empty)
+ extends AnalysisOnlyCommand with CTEInChildren {
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): CacheTableAsSelect = {
assert(!isAnalyzed)
@@ -1757,6 +1758,10 @@ case class CacheTableAsSelect(
// Collect the referred temporary functions from AnalysisContext
referredTempFunctions = ac.referredTempFunctionNames.toSeq)
}
+
+  /** Wraps this command's query in a [[WithCTE]] that carries `cteDefs`. */
+  override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan =
+    copy(plan = WithCTE(plan = plan, cteDefs = cteDefs))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 24bf618ee861..7e3a6b9dbb7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTr
import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder,
SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Expression,
ExtractSemiStructuredFields}
+import org.apache.spark.sql.catalyst.normalizer.NormalizeCTEIds
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -404,6 +405,7 @@ abstract class BaseSessionStateBuilder(
}
protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
+ NormalizeCTEIds +: // built-in CTE id normalization, prepended ahead of extension-provided rules
extensions.buildPlanNormalizationRules(session)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 12d26c4e195f..880d8d72c73e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -2598,6 +2598,44 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
}
}
+  test("SPARK-46741: Cache Table with CTE should work") {
+    withTempView("t1", "t2", "cache_nested_cte_table") {
+      sql(
+        """
+          |CREATE TEMPORARY VIEW t1
+          |AS
+          |SELECT * FROM VALUES (0, 0), (1, 1), (2, 2) AS t(c1, c2)
+          |""".stripMargin)
+      sql(
+        """
+          |CREATE TEMPORARY VIEW t2 AS
+          |WITH v as (
+          | SELECT c1 + c1 c3 FROM t1
+          |)
+          |SELECT SUM(c3) s FROM v
+          |""".stripMargin)
+      sql(
+        """
+          |CACHE TABLE cache_nested_cte_table
+          |WITH
+          |v AS (
+          | SELECT c1 * c2 c3 from t1
+          |)
+          |SELECT SUM(c3) FROM v
+          |EXCEPT
+          |SELECT s FROM t2
+          |""".stripMargin)
+      // Reading the cached name must hit the cache: exactly one InMemoryTableScanExec.
+      val df = sql("SELECT * FROM cache_nested_cte_table")
+      val inMemoryTableScan = collect(df.queryExecution.executedPlan) {
+        case i: InMemoryTableScanExec => i
+      }
+      assert(inMemoryTableScan.size == 1)
+      // SUM(c1 * c2) over t1 is 0 + 1 + 4 = 5; t2 yields SUM(c1 + c1) = 6, so EXCEPT keeps 5.
+      checkAnswer(df, Row(5) :: Nil)
+    }
+  }
+
private def cacheManager = spark.sharedState.cacheManager
private def pinTable(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]