This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 76c9516417d1 [SPARK-54835][SQL] Avoid unnecessary temp QueryExecution for nested command execution
76c9516417d1 is described below

commit 76c9516417d1886fd0378247837eed8fff6cec6a
Author: Wenchen Fan <[email protected]>
AuthorDate: Fri Dec 26 16:39:34 2025 +0800

    [SPARK-54835][SQL] Avoid unnecessary temp QueryExecution for nested command execution
    
    ### What changes were proposed in this pull request?
    
    This PR is a small refactor. In the DS v2 CTAS/RTAS commands, we run a nested
    `AppendData`/`OverwriteByExpression` command by creating a `QueryExecution`.
    That `QueryExecution` would in turn create another temporary `QueryExecution`
    just to eagerly execute the command. This PR avoids the unnecessary temporary
    `QueryExecution` by creating the `QueryExecution` with
    `CommandExecutionMode.SKIP`.
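    
    For illustration, a minimal sketch of the call-pattern change at the nested
    write site (based on the diff below; this is not standalone code since it uses
    Spark-internal APIs, and the `writeNested` wrapper and the exact import paths
    are assumptions of this sketch, not part of the change):
    
    ```scala
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.classic.SparkSession
    import org.apache.spark.sql.execution.QueryExecution
    
    // `writeCommand` stands in for the AppendData/OverwriteByExpression plan
    // that the CTAS/RTAS command builds for its nested write.
    def writeNested(session: SparkSession, writeCommand: LogicalPlan): Unit = {
      // Before: the QueryExecution used the default CommandExecutionMode.ALL,
      // so assertCommandExecuted() spawned an extra temp QueryExecution just to
      // eagerly run the nested command:
      //   val qe = QueryExecution.create(session, writeCommand, refreshPhaseEnabled)
      //   qe.assertCommandExecuted()
    
      // After: runCommand builds the QueryExecution with CommandExecutionMode.SKIP
      // and executes the physical plan once under its own execution id.
      QueryExecution.runCommand(
        session,
        writeCommand,
        "inner data writing for CTAS/RTAS",
        refreshPhaseEnabled = true)
    }
    ```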
    
    ### Why are the changes needed?
    
    Remove useless temp `QueryExecution` objects.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    existing tests, plus a new test in `DataSourceV2DataFrameSuite` that verifies
    CTAS/RTAS trigger exactly two query executions
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    cursor 2.2.43
    
    Closes #53596 from cloud-fan/command.
    
    Lead-authored-by: Wenchen Fan <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../spark/sql/execution/QueryExecution.scala       | 33 +++++++++++++++-----
 .../sql/execution/datasources/DataSource.scala     |  4 +--
 .../datasources/v2/WriteToDataSourceV2Exec.scala   |  7 +++--
 .../sql/connector/DataSourceV2DataFrameSuite.scala | 35 ++++++++++++++++++++++
 4 files changed, 68 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 3e0aef962e71..652cf5aa7c3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -178,13 +178,8 @@ class QueryExecution(
       // with the rest of processing of the root plan being just outputting command results,
       // for eagerly executed commands we mark this place as beginning of execution.
       tracker.setReadyForExecution()
-      val qe = new QueryExecution(sparkSession, p, mode = mode,
-        shuffleCleanupMode = shuffleCleanupMode, refreshPhaseEnabled = refreshPhaseEnabled)
-      val result = QueryExecution.withInternalError(s"Eagerly executed $name failed.") {
-        SQLExecution.withNewExecutionId(qe, Some(name)) {
-          qe.executedPlan.executeCollect()
-        }
-      }
+      val (qe, result) = QueryExecution.runCommand(
+        sparkSession, p, name, refreshPhaseEnabled, mode, Some(shuffleCleanupMode))
       CommandResult(
         qe.analyzed.output,
         qe.commandExecuted,
@@ -763,4 +758,28 @@ object QueryExecution {
       case _ => false
     }
   }
+
+  def runCommand(
+      sparkSession: SparkSession,
+      command: LogicalPlan,
+      name: String,
+      refreshPhaseEnabled: Boolean = true,
+      mode: CommandExecutionMode.Value = CommandExecutionMode.SKIP,
+      shuffleCleanupModeOpt: Option[ShuffleCleanupMode] = None)
+    : (QueryExecution, Array[InternalRow]) = {
+    val shuffleCleanupMode = shuffleCleanupModeOpt.getOrElse(
+      determineShuffleCleanupMode(sparkSession.sessionState.conf))
+    val qe = new QueryExecution(
+      sparkSession,
+      command,
+      mode = mode,
+      shuffleCleanupMode = shuffleCleanupMode,
+      refreshPhaseEnabled = refreshPhaseEnabled)
+    val result = QueryExecution.withInternalError(s"Executed $name failed.") {
+      SQLExecution.withNewExecutionId(qe, Some(name)) {
+        qe.executedPlan.executeCollect()
+      }
+    }
+    (qe, result)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index da3204eb221b..35588df11bfc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl
 import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.connector.catalog.TableProvider
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.DataWritingCommand
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
@@ -531,8 +532,7 @@ case class DataSource(
         disallowWritingIntervals(
           outputColumns.toStructType.asNullable, format.toString, forbidAnsiIntervals = false)
         val cmd = planForWritingFileFormat(format, mode, data)
-        val qe = sessionState(sparkSession).executePlan(cmd)
-        qe.assertCommandExecuted()
+        QueryExecution.runCommand(sparkSession, cmd, "file source write")
         // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
         copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
       case _ => throw SparkException.internalError(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 3e4a2f792a1c..6d874fb29e98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -737,8 +737,11 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
       } else {
         AppendData.byPosition(relation, query, writeOptions)
       }
-      val qe = QueryExecution.create(session, writeCommand, refreshPhaseEnabled)
-      qe.assertCommandExecuted()
+      QueryExecution.runCommand(
+        session,
+        writeCommand,
+        "inner data writing for CTAS/RTAS",
+        refreshPhaseEnabled)
       DataSourceV2Utils.commitStagedChanges(sparkContext, table, metrics)
       Nil
     })(catchBlock = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index c3164b3428f9..3a327fcf9863 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -2053,4 +2053,39 @@ class DataSourceV2DataFrameSuite
       case _ => fail(s"can't pin $ident in $catalogName")
     }
   }
+
+  test("CTAS/RTAS should trigger two query executions") {
+    // CTAS/RTAS triggers 2 query executions:
+    // 1. The outer CTAS/RTAS command execution
+    // 2. The inner AppendData/OverwriteByExpression execution
+    var executionCount = 0
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        executionCount += 1
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+
+    try {
+      spark.listenerManager.register(listener)
+      val t = "testcat.ns1.ns2.tbl"
+      withTable(t) {
+        // Test CTAS (CreateTableAsSelect)
+        executionCount = 0
+        sql(s"CREATE TABLE $t USING foo AS SELECT 1 as id, 'a' as data")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(executionCount == 2,
+          s"CTAS should trigger 2 executions, got $executionCount")
+
+        // Test RTAS (ReplaceTableAsSelect)
+        executionCount = 0
+        sql(s"CREATE OR REPLACE TABLE $t USING foo AS SELECT 2 as id, 'b' as 
data")
+        sparkContext.listenerBus.waitUntilEmpty()
+        assert(executionCount == 2,
+          s"RTAS should trigger 2 executions, got $executionCount")
+      }
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
