This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 76c9516417d1 [SPARK-54835][SQL] Avoid unnecessary temp QueryExecution for nested command execution
76c9516417d1 is described below
commit 76c9516417d1886fd0378247837eed8fff6cec6a
Author: Wenchen Fan <[email protected]>
AuthorDate: Fri Dec 26 16:39:34 2025 +0800
[SPARK-54835][SQL] Avoid unnecessary temp QueryExecution for nested command execution
### What changes were proposed in this pull request?
This PR is a small refactor. In the DS v2 CTAS/RTAS commands, we run the nested
`AppendData`/`OverwriteByExpression` command by creating a `QueryExecution`.
That `QueryExecution` in turn creates another temporary `QueryExecution` just to
eagerly execute the command. This PR avoids the unnecessary temporary
`QueryExecution` by creating the `QueryExecution` with `CommandExecutionMode.SKIP`.
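To make the mechanism concrete, here is a minimal before/after sketch of the
CTAS/RTAS call site (the identifiers mirror the `WriteToDataSourceV2Exec` diff
below; the comments are editorial, not part of the patch):

    // Before: the created QueryExecution eagerly executes the nested write
    // command (CommandExecutionMode.ALL by default), which spins up a second,
    // throw-away QueryExecution inside eagerlyExecuteCommands.
    val qe = QueryExecution.create(session, writeCommand, refreshPhaseEnabled)
    qe.assertCommandExecuted()

    // After: the new helper builds its QueryExecution with
    // CommandExecutionMode.SKIP, so the command plan is left as-is and executed
    // exactly once under a fresh execution id via SQLExecution.withNewExecutionId.
    QueryExecution.runCommand(
      session, writeCommand, "inner data writing for CTAS/RTAS", refreshPhaseEnabled)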
### Why are the changes needed?
Remove useless temp `QueryExecution` objects.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests, plus a new test in `DataSourceV2DataFrameSuite` verifying that CTAS/RTAS trigger exactly two query executions
### Was this patch authored or co-authored using generative AI tooling?
cursor 2.2.43
Closes #53596 from cloud-fan/command.
Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../spark/sql/execution/QueryExecution.scala | 33 +++++++++++++++-----
.../sql/execution/datasources/DataSource.scala | 4 +--
.../datasources/v2/WriteToDataSourceV2Exec.scala | 7 +++--
.../sql/connector/DataSourceV2DataFrameSuite.scala | 35 ++++++++++++++++++++++
4 files changed, 68 insertions(+), 11 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 3e0aef962e71..652cf5aa7c3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -178,13 +178,8 @@ class QueryExecution(
       // with the rest of processing of the root plan being just outputting command results,
       // for eagerly executed commands we mark this place as beginning of execution.
       tracker.setReadyForExecution()
-      val qe = new QueryExecution(sparkSession, p, mode = mode,
-        shuffleCleanupMode = shuffleCleanupMode, refreshPhaseEnabled = refreshPhaseEnabled)
-      val result = QueryExecution.withInternalError(s"Eagerly executed $name failed.") {
-        SQLExecution.withNewExecutionId(qe, Some(name)) {
-          qe.executedPlan.executeCollect()
-        }
-      }
+      val (qe, result) = QueryExecution.runCommand(
+        sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode))
       CommandResult(
         qe.analyzed.output,
         qe.commandExecuted,
@@ -763,4 +758,28 @@ object QueryExecution {
       case _ => false
     }
   }
+
+  def runCommand(
+      sparkSession: SparkSession,
+      command: LogicalPlan,
+      name: String,
+      refreshPhaseEnabled: Boolean = true,
+      mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP,
+      shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None)
+    : (QueryExecution, Array[InternalRow]) = {
+    val shuffleCleanupMode = shuffleCleanupModeOpt.getOrElse(
+      determineShuffleCleanupMode(sparkSession.sessionState.conf))
+    val qe = new QueryExecution(
+      sparkSession,
+      command,
+      mode = mode,
+      shuffleCleanupMode = shuffleCleanupMode,
+      refreshPhaseEnabled = refreshPhaseEnabled)
+    val result = QueryExecution.withInternalError(s"Executed $name failed.") {
+      SQLExecution.withNewExecutionId(qe, Some(name)) {
+        qe.executedPlan.executeCollect()
+      }
+    }
+    (qe, result)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index da3204eb221b..35588df11bfc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl
 import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.connector.catalog.TableProvider
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
@@ -531,8 +532,7 @@ case class DataSource(
         disallowWritingIntervals(
           outputColumns.toStructType.asNullable, format.toString, forbidAnsiIntervals = false)
         val cmd = planForWritingFileFormat(format, mode, data)
-        val qe = sessionState(sparkSession).executePlan(cmd)
-        qe.assertCommandExecuted()
+        QueryExecution.runCommand(sparkSession, cmd, "file source write")
         // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
         copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
       case _ => throw SparkException.internalError(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 3e4a2f792a1c..6d874fb29e98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -737,8 +737,11 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
       } else {
         AppendData.byPosition(relation, query, writeOptions)
       }
-      val qe = QueryExecution.create(session, writeCommand, refreshPhaseEnabled)
-      qe.assertCommandExecuted()
+      QueryExecution.runCommand(
+        session,
+        writeCommand,
+        "inner data writing for CTAS/RTAS",
+        refreshPhaseEnabled)
       DataSourceV2Utils.commitStagedChanges(sparkContext, table, metrics)
       Nil
     })(catchBlock = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index c3164b3428f9..3a327fcf9863 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -2053,4 +2053,39 @@ class DataSourceV2DataFrameSuite
       case _ => fail(s"can't pin $ident in $catalogName")
     }
   }
+
+  test("CTAS/RTAS should trigger two query executions") {
+    // CTAS/RTAS triggers 2 query executions:
+    // 1. The outer CTAS/RTAS command execution
+    // 2. The inner AppendData/OverwriteByExpression execution
+    var executionCount = 0
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        executionCount += 1
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+
+    try {
+      spark.listenerManager.register(listener)
+      val t = "testcat.ns1.ns2.tbl"
+      withTable(t) {
+        // Test CTAS (CreateTableAsSelect)
+        executionCount = 0
+        sql(s"CREATE TABLE $t USING foo AS SELECT 1 as id, 'a' as data")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(executionCount == 2,
+          s"CTAS should trigger 2 executions, got $executionCount")
+
+        // Test RTAS (ReplaceTableAsSelect)
+        executionCount = 0
+        sql(s"CREATE OR REPLACE TABLE $t USING foo AS SELECT 2 as id, 'b' as data")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(executionCount == 2,
+          s"RTAS should trigger 2 executions, got $executionCount")
+      }
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 }