This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 436ca7a7317a [SPARK-52101][SQL] Disable Inline Forcing for rCTEs
436ca7a7317a is described below
commit 436ca7a7317a6b9151e334b0bec059a58c0a4f18
Author: pavle-martinovic_data <[email protected]>
AuthorDate: Mon May 19 20:36:27 2025 +0800
[SPARK-52101][SQL] Disable Inline Forcing for rCTEs
### What changes were proposed in this pull request?
Enable DML commands to work with rCTEs by handling a `WithCTE` node above a
command in eagerlyExecuteCommands.
### Why are the changes needed?
This enables UPDATE, DELETE and MERGE statements to work with rCTEs.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests in MergeIntoTableSuiteBase, DeleteFromTableSuiteBase,
UpdateTableSuiteBase.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50904 from Pajaraja/pavle-martinovic_data/cmdctes.
Authored-by: pavle-martinovic_data <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 4 +-
.../sql/catalyst/analysis/CTESubstitution.scala | 21 +++++----
.../spark/sql/execution/QueryExecution.scala | 7 ++-
.../analyzer-results/cte-recursion.sql.out | 40 ++++++++++++-----
.../resources/sql-tests/inputs/cte-recursion.sql | 2 +
.../sql-tests/results/cte-recursion.sql.out | 40 ++++++++++++-----
.../sql/connector/DeleteFromTableSuiteBase.scala | 39 +++++++++++++++++
.../sql/connector/MergeIntoTableSuiteBase.scala | 51 ++++++++++++++++++++++
.../spark/sql/connector/UpdateTableSuiteBase.scala | 42 ++++++++++++++++++
9 files changed, 211 insertions(+), 35 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index e2fed243c476..b34855be9ca2 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4660,9 +4660,9 @@
],
"sqlState" : "42836"
},
- "RECURSIVE_CTE_WHEN_INLINING_IS_FORCED" : {
+ "RECURSIVE_CTE_WITH_LEGACY_INLINE_FLAG" : {
"message" : [
- "Recursive definitions cannot be used when CTE inlining is forced."
+ "Recursive definitions cannot be used when legacy inline flag is set to
true (spark.sql.legacy.inlineCTEInCommands=true)."
],
"sqlState" : "42836"
},
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index 372580c628bd..0b6f8ec87417 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -65,19 +65,29 @@ object CTESubstitution extends Rule[LogicalPlan] {
case p => p.children.flatMap(collectCommands)
}
val commands = collectCommands(plan)
+ val hasRecursiveCTE = plan.collectFirstWithSubqueries {
+ case UnresolvedWith(_, _, true) =>
+ true
+ }.getOrElse(false)
+
+ // If the CTE is recursive we can't inline it as it has a self reference.
val forceInline = if (commands.length == 1) {
if (conf.getConf(SQLConf.LEGACY_INLINE_CTE_IN_COMMANDS)) {
// The legacy behavior always inlines the CTE relations for queries in
commands.
+ if (hasRecursiveCTE) {
+ plan.failAnalysis(errorClass =
"RECURSIVE_CTE_WITH_LEGACY_INLINE_FLAG",
+ messageParameters = Map.empty)
+ }
true
} else {
// If there is only one command and it's `CTEInChildren`, we can
resolve
// CTE normally and don't need to force inline.
- !commands.head.isInstanceOf[CTEInChildren]
+ !hasRecursiveCTE && !commands.head.isInstanceOf[CTEInChildren]
}
} else if (commands.length > 1) {
// This can happen with the multi-insert statement. We should fall back
to
- // the legacy behavior.
- true
+ // the legacy behavior, unless the CTE is recursive.
+ !hasRecursiveCTE
} else {
false
}
@@ -219,11 +229,6 @@ object CTESubstitution extends Rule[LogicalPlan] {
_.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) {
// allowRecursion flag is set to `True` by the parser if the `RECURSIVE`
keyword is used.
case cte @ UnresolvedWith(child: LogicalPlan, relations, allowRecursion)
=>
- if (allowRecursion && forceInline) {
- cte.failAnalysis(
- errorClass = "RECURSIVE_CTE_WHEN_INLINING_IS_FORCED",
- messageParameters = Map.empty)
- }
val tempCteDefs = ArrayBuffer.empty[CTERelationDef]
val resolvedCTERelations = if (recursiveCTERelationAncestor.isDefined)
{
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index f5e803ea9c97..9e5264d8d4f3 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.{InternalRow,
QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{LazyExpression,
NameParameterizedQuery, UnsupportedOperationChecker}
import org.apache.spark.sql.catalyst.expressions.codegen.ByteCodeStats
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command,
CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan,
OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect,
ReturnAnswer, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, Command,
CommandResult, CompoundBody, CreateTableAsSelect, LogicalPlan,
OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect,
ReturnAnswer, Union, WithCTE}
import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -164,9 +164,14 @@ class QueryExecution(
p transformDown {
case u @ Union(children, _, _) if
children.forall(_.isInstanceOf[Command]) =>
eagerlyExecute(u, "multi-commands", CommandExecutionMode.SKIP)
+ case w @ WithCTE(u @ Union(children, _, _), _) if
children.forall(_.isInstanceOf[Command]) =>
+ eagerlyExecute(w, "multi-commands", CommandExecutionMode.SKIP)
case c: Command =>
val name = commandExecutionName(c)
eagerlyExecute(c, name, CommandExecutionMode.NON_ROOT)
+ case w @ WithCTE(c: Command, _) =>
+ val name = commandExecutionName(c)
+ eagerlyExecute(w, name, CommandExecutionMode.SKIP)
}
}
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
index 1ea74be18887..4aff03883865 100644
---
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
+++
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
@@ -1015,18 +1015,34 @@ FROM r
INSERT INTO rt2 SELECT *
INSERT INTO rt2 SELECT *
-- !query analysis
-org.apache.spark.sql.AnalysisException
-{
- "errorClass" : "RECURSIVE_CTE_WHEN_INLINING_IS_FORCED",
- "sqlState" : "42836",
- "queryContext" : [ {
- "objectType" : "",
- "objectName" : "",
- "startIndex" : 1,
- "stopIndex" : 160,
- "fragment" : "WITH RECURSIVE r(level) AS (\n VALUES (0)\n UNION
ALL\n SELECT level + 1 FROM r WHERE level < 9\n)\nFROM r\nINSERT INTO rt2
SELECT *\nINSERT INTO rt2 SELECT *"
- } ]
-}
+WithCTE
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias r
+: +- Project [col1#x AS level#x]
+: +- UnionLoop xxxx
+: :- LocalRelation [col1#x]
+: +- Project [(level#x + 1) AS (level + 1)#x]
+: +- Filter (level#x < 9)
+: +- SubqueryAlias r
+: +- Project [col1#x AS level#x]
+: +- UnionLoopRef xxxx, [col1#x], false
++- Union false, false
+ :- InsertIntoHadoopFsRelationCommand file:[not included in
comparison]/{warehouse_dir}/rt2, false, CSV, [path=file:[not included in
comparison]/{warehouse_dir}/rt2], Append, `spark_catalog`.`default`.`rt2`,
org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included
in comparison]/{warehouse_dir}/rt2), [level]
+ : +- Project [level#x]
+ : +- SubqueryAlias r
+ : +- CTERelationRef xxxx, true, [level#x], false, false
+ +- InsertIntoHadoopFsRelationCommand file:[not included in
comparison]/{warehouse_dir}/rt2, false, CSV, [path=file:[not included in
comparison]/{warehouse_dir}/rt2], Append, `spark_catalog`.`default`.`rt2`,
org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included
in comparison]/{warehouse_dir}/rt2), [level]
+ +- Project [level#x]
+ +- SubqueryAlias r
+ +- CTERelationRef xxxx, true, [level#x], false, false
+
+
+-- !query
+SELECT * FROM rt2
+-- !query analysis
+Project [level#x]
++- SubqueryAlias spark_catalog.default.rt2
+ +- Relation spark_catalog.default.rt2[level#x] csv
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
index 8cb77b96e3cd..fba8861083be 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
@@ -392,6 +392,8 @@ FROM r
INSERT INTO rt2 SELECT *
INSERT INTO rt2 SELECT *;
+SELECT * FROM rt2;
+
DROP TABLE rt2;
-- multiple recursive CTEs
diff --git
a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
index ea2cfb535ccd..06f440a3f633 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
@@ -985,18 +985,34 @@ INSERT INTO rt2 SELECT *
-- !query schema
struct<>
-- !query output
-org.apache.spark.sql.AnalysisException
-{
- "errorClass" : "RECURSIVE_CTE_WHEN_INLINING_IS_FORCED",
- "sqlState" : "42836",
- "queryContext" : [ {
- "objectType" : "",
- "objectName" : "",
- "startIndex" : 1,
- "stopIndex" : 160,
- "fragment" : "WITH RECURSIVE r(level) AS (\n VALUES (0)\n UNION
ALL\n SELECT level + 1 FROM r WHERE level < 9\n)\nFROM r\nINSERT INTO rt2
SELECT *\nINSERT INTO rt2 SELECT *"
- } ]
-}
+
+
+
+-- !query
+SELECT * FROM rt2
+-- !query schema
+struct<level:int>
+-- !query output
+0
+0
+1
+1
+2
+2
+3
+3
+4
+4
+5
+5
+6
+6
+7
+7
+8
+8
+9
+9
-- !query
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
index 585480ace725..d394c6f12c7e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
@@ -587,6 +587,45 @@ abstract class DeleteFromTableSuiteBase extends
RowLevelOperationSuiteBase {
}
}
+
+ test("delete from table with recursive CTE") {
+ withTempView("source") {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ | val INT)
+ |""".stripMargin)
+
+ append("val INT",
+ """{ "val": 1 }
+ |{ "val": 9 }
+ |{ "val": 8 }
+ |{ "val": 4 }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1),
+ Row(9),
+ Row(8),
+ Row(4)))
+
+ sql(
+ s"""WITH RECURSIVE s(val) AS (
+ | SELECT 1
+ | UNION ALL
+ | SELECT val + 1 FROM s WHERE val < 5
+ |) DELETE FROM $tableNameAsString WHERE val IN (SELECT val FROM s)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(9),
+ Row(8)))
+ }
+ }
+
private def executeDeleteWithFilters(query: String): Unit = {
val executedPlan = executeAndKeepPlan {
sql(query)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 86471ff8c456..b43424793d44 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -1675,6 +1675,57 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase {
}
}
+ test("merge into table with recursive CTE") {
+ withTempView("source") {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ | val INT)
+ |""".stripMargin)
+
+ append("val INT",
+ """{ "val": 1 }
+ |{ "val": 9 }
+ |{ "val": 8 }
+ |{ "val": 4 }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1),
+ Row(9),
+ Row(8),
+ Row(4)))
+
+ sql(
+ s"""WITH RECURSIVE s(val) AS (
+ | SELECT 1
+ | UNION ALL
+ | SELECT val + 1 FROM s WHERE val < 5
+ |) MERGE INTO $tableNameAsString t
+ |USING s
+ |ON t.val = s.val
+ |WHEN MATCHED THEN
+ | UPDATE SET t.val = t.val - 1
+ |WHEN NOT MATCHED THEN
+ | INSERT (val) VALUES (-s.val)
+ |WHEN NOT MATCHED BY SOURCE THEN
+ | UPDATE SET t.val = t.val + 1
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0),
+ Row(10),
+ Row(9),
+ Row(3),
+ Row(-2),
+ Row(-3),
+ Row(-5)))
+ }
+ }
+
private def assertNoLeftBroadcastOrReplication(query: String): Unit = {
val plan = executeAndKeepPlan {
sql(query)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
index d33ad2494c3c..0c3ed5106eba 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
@@ -620,4 +620,46 @@ abstract class UpdateTableSuiteBase extends
RowLevelOperationSuiteBase {
sqlState = "42000",
parameters = Map("walkedTypePath" -> "\ns\nn_i\n"))
}
+
+
+
+ test("update table with recursive CTE") {
+ withTempView("source") {
+ sql(
+ s"""CREATE TABLE $tableNameAsString (
+ | val INT)
+ |""".stripMargin)
+
+ append("val INT",
+ """{ "val": 1 }
+ |{ "val": 9 }
+ |{ "val": 8 }
+ |{ "val": 4 }
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1),
+ Row(9),
+ Row(8),
+ Row(4)))
+
+ sql(
+ s"""WITH RECURSIVE s(val) AS (
+ | SELECT 1
+ | UNION ALL
+ | SELECT val + 1 FROM s WHERE val < 5
+ |) UPDATE $tableNameAsString SET val = val + 1 WHERE val IN (SELECT
val FROM s)
+ |""".stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(2),
+ Row(9),
+ Row(8),
+ Row(5)))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]