This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1368983b58f1 [SPARK-52134][SQL][FOLLOW UP] SQL Script execution code - refactor follow up
1368983b58f1 is described below

commit 1368983b58f1f49e7e420405a4f399471dfb1512
Author: David Milicevic <[email protected]>
AuthorDate: Thu Oct 30 10:02:37 2025 +0800

    [SPARK-52134][SQL][FOLLOW UP] SQL Script execution code - refactor follow up
    
    ### What changes were proposed in this pull request?
    
    The original change was introduced in https://github.com/apache/spark/pull/50895. This change is a follow-up to:
    1. Restructure the code more cleanly (see the sketch after this list).
    2. Add more meaningful Spark Connect tests and move them to a more appropriate suite.
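
    For illustration only (not part of this patch), here is a minimal, self-contained Scala sketch of the check that the refactor centralizes in `QueryExecution.isUnresolvedPlanSqlScript`; the plan types below are simplified stand-ins, not the real Catalyst classes:

    ```scala
    // Simplified stand-ins for the Catalyst plan nodes involved in the change.
    sealed trait LogicalPlan
    case class CompoundBody(statements: Seq[String]) extends LogicalPlan
    case class NameParameterizedQuery(child: LogicalPlan, argNames: Seq[String]) extends LogicalPlan
    case class Project(exprs: Seq[String]) extends LogicalPlan

    object SqlScriptCheckSketch {
      // Mirrors the shape of the new helper: a plan is a SQL script if it is a
      // CompoundBody, possibly wrapped in a name-parameterized query.
      def isUnresolvedPlanSqlScript(plan: LogicalPlan): Boolean = plan match {
        case _: CompoundBody => true
        case NameParameterizedQuery(_: CompoundBody, _) => true
        case _ => false
      }

      def main(args: Array[String]): Unit = {
        val script = CompoundBody(Seq("SELECT 1"))
        assert(isUnresolvedPlanSqlScript(script))
        assert(isUnresolvedPlanSqlScript(NameParameterizedQuery(script, Seq("param"))))
        assert(!isUnresolvedPlanSqlScript(Project(Seq("col"))))
        println("SQL script detection behaves as expected")
      }
    }
    ```

    With the check exposed as `df.queryExecution.isSqlScript`, call sites such as SparkConnectPlanner no longer need to pattern match on `CompoundBody` directly.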
    
    ### Why are the changes needed?
    
    Cleaner and more consistent code. No functional changes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    There shouldn't be any functional changes. The existing tests should be enough to confirm that the change works as expected.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52786 from davidm-db/davidm-db/execute_sql_script_refactor_v2.
    
    Authored-by: David Milicevic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/connect/ClientE2ETestSuite.scala     | 56 ++++++++++++++++++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  4 +-
 .../spark/sql/connect/SparkConnectServerTest.scala | 42 ----------------
 .../service/SparkConnectServiceE2ESuite.scala      | 21 --------
 .../spark/sql/execution/QueryExecution.scala       | 56 ++++++++++++++++++----
 5 files changed, 105 insertions(+), 74 deletions(-)

diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index db165c03ad35..8c336b6fa6d5 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1792,6 +1792,62 @@ class ClientE2ETestSuite
     assert(result.length === 1)
     assert(result(0).getAs[Array[Integer]]("arr_col") === Array(1, null))
   }
+
+  // SQL Scripting tests
+  test("SQL Script result") {
+    val df = spark.sql("""BEGIN
+        |  IF 1=1 THEN
+        |    SELECT 1;
+        |  ELSE
+        |    SELECT 3;
+        |  END IF;
+        |END
+        |""".stripMargin)
+    checkAnswer(df, Seq(Row(1)))
+  }
+
+  test("SQL Script schema") {
+    withTable("script_tbl") {
+      val df = spark.sql("""BEGIN
+          |  CREATE TABLE script_tbl (a INT, b STRING);
+          |  INSERT INTO script_tbl VALUES (1, 'Hello'), (2, 'World');
+          |  SELECT * FROM script_tbl;
+          |END
+          |""".stripMargin)
+      assert(
+        df.schema == StructType(
+          StructField("a", IntegerType, nullable = true)
+            :: StructField("b", StringType, nullable = true)
+            :: Nil))
+    }
+  }
+
+  test("SQL Script empty result") {
+    withTable("script_tbl") {
+      val df = spark.sql("""BEGIN
+          |  CREATE TABLE script_tbl (a INT, b STRING);
+          |  SELECT * FROM script_tbl;
+          |END
+          |""".stripMargin)
+      assert(
+        df.schema == StructType(
+          StructField("a", IntegerType, nullable = true)
+            :: StructField("b", StringType, nullable = true)
+            :: Nil))
+      checkAnswer(df, Seq.empty)
+    }
+  }
+
+  test("SQL Script no result") {
+    withTable("script_tbl") {
+      val df = spark.sql("""BEGIN
+          |  CREATE TABLE script_tbl (a INT, b STRING);
+          |END
+          |""".stripMargin)
+      assert(df.schema == StructType(Nil))
+      checkAnswer(df, Seq.empty)
+    }
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ebcf462b84ce..41efe8db842f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -52,7 +52,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.{NamedParameterContext, ParameterContext, ParseException, ParserUtils, PositionalParameterContext}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, CompoundBody, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Un [...]
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Unpivot, Unresol [...]
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -2975,7 +2975,7 @@ class SparkConnectPlanner(
 
     // Check if command or SQL Script has been executed.
     val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
-    val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
+    val isSqlScript = df.queryExecution.isSqlScript
     val rows = df.logicalPlan match {
       case lr: LocalRelation => lr.data
       case cr: CommandResult => cr.rows
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 91e728d73e13..1b2b7ab42029 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -16,19 +16,16 @@
  */
 package org.apache.spark.sql.connect
 
-import java.io.ByteArrayInputStream
 import java.util.{TimeZone, UUID}
 
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.scalatest.concurrent.{Eventually, TimeLimits}
 import org.scalatest.time.Span
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
@@ -323,43 +320,4 @@ trait SparkConnectServerTest extends SharedSparkSession {
     val plan = buildPlan(query)
     runQuery(plan, queryTimeout, iterSleep)
   }
-
-  protected def checkSqlCommandResponse(
-      result: ExecutePlanResponse.SqlCommandResult,
-      expected: Seq[Seq[Any]]): Unit = {
-    // Extract the serialized Arrow data as a byte array.
-    val dataBytes = result.getRelation.getLocalRelation.getData.toByteArray
-
-    // Create an ArrowStreamReader to deserialize the data.
-    val allocator = new RootAllocator(Long.MaxValue)
-    val inputStream = new ByteArrayInputStream(dataBytes)
-    val reader = new ArrowStreamReader(inputStream, allocator)
-
-    try {
-      // Read the schema and data.
-      val root = reader.getVectorSchemaRoot
-      // Load the first batch of data.
-      reader.loadNextBatch()
-
-      // Get dimensions.
-      val rowCount = root.getRowCount
-      val colCount = root.getFieldVectors.size
-      assert(rowCount == expected.length, "Row count mismatch")
-      assert(colCount == expected.head.length, "Column count mismatch")
-
-      // Compare to expected.
-      for (i <- 0 until rowCount) {
-        for (j <- 0 until colCount) {
-          val col = root.getFieldVectors.get(j)
-          val value = col.getObject(i)
-          print(value)
-          assert(value == expected(i)(j), s"Value mismatch at ($i, $j)")
-        }
-      }
-    } finally {
-      // Clean up resources.
-      reader.close()
-      allocator.close()
-    }
-  }
 }
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index e3ba35073f41..0e18ff711c4c 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -33,27 +33,6 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
   // were all already in the buffer.
   val BIG_ENOUGH_QUERY = "select * from range(1000000)"
 
-  test("SQL Script over Spark Connect.") {
-    val sessionId = UUID.randomUUID.toString()
-    val userId = "ScriptUser"
-    val sqlScriptText =
-      """BEGIN
-        |IF 1 = 1 THEN
-        |  SELECT 1;
-        |ELSE
-        |  SELECT 2;
-        |END IF;
-        |END
-        """.stripMargin
-    withClient(sessionId = sessionId, userId = userId) { client =>
-      // this will create the session, and then ReleaseSession at the end of withClient.
-      val enableSqlScripting = client.execute(buildPlan("SET spark.sql.scripting.enabled=true"))
-      enableSqlScripting.hasNext // trigger execution
-      val query = client.execute(buildSqlCommandPlan(sqlScriptText))
-      checkSqlCommandResponse(query.next().getSqlCommandResult, Seq(Seq(1)))
-    }
-  }
-
   test("Execute is sent eagerly to the server upon iterator creation") {
     // This behavior changed with grpc upgrade from 1.56.0 to 1.59.0.
     // Testing to be aware of future changes.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b46172001a87..27d6eec46b69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -72,6 +72,12 @@ class QueryExecution(
   // TODO: Move the planner an optimizer into here from SessionState.
   protected def planner = sparkSession.sessionState.planner
 
+  /**
+   * Check whether the query represented by this QueryExecution is a SQL script.
+   * @return True if the query is a SQL script, False otherwise.
+   */
+  def isSqlScript: Boolean = QueryExecution.isUnresolvedPlanSqlScript(logical)
+
   lazy val isLazyAnalysis: Boolean = {
     // Only check the main query as subquery expression can be resolved now with the main query.
     logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression])))
@@ -95,27 +101,46 @@ class QueryExecution(
     }
   }
 
-  private val lazyAnalyzed = LazyTry {
-    val withScriptExecuted = logical match {
-      // Execute the SQL script. Script doesn't need to go through the analyzer as Spark will run
-      // each statement as individual query.
+  /**
+   * Execute the SQL script if the logical plan is a SQL script.
+   * There are multiple cases, and they are originating from:
+   *   - SparkSession.sql() - Spark and Spark Connect case
+   *   - QueryRuntimePredictionUtils.getParsedPlanWithTracking() - DBSQL case
+   */
+  private val lazySqlScriptExecuted = LazyTry {
+    logical match {
       case NameParameterizedQuery(compoundBody: CompoundBody, argNames, argValues) =>
-        val args = argNames.zip(argValues).toMap
-        SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody, args)
+        SqlScriptingExecution.executeSqlScript(
+          session = sparkSession,
+          script = compoundBody,
+          args = argNames.zip(argValues).toMap)
       case compoundBody: CompoundBody =>
-        SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody)
+        SqlScriptingExecution.executeSqlScript(
+          session = sparkSession,
+          script = compoundBody)
       case _ => logical
     }
+  }
+
+  private def sqlScriptExecuted: LogicalPlan = lazySqlScriptExecuted.get
+
+  private def assertSqlScriptExecuted(): Unit = sqlScriptExecuted
+
+  private val lazyAnalyzed = LazyTry {
+    // Execute the SQL script. Script doesn't need to go through the analyzer as Spark
+    //   will run each statement as individual query.
+    assertSqlScriptExecuted()
+
     try {
       val plan = executePhase(QueryPlanningTracker.ANALYSIS) {
         // We can't clone `logical` here, which will reset the `_analyzed` flag.
-        sparkSession.sessionState.analyzer.executeAndCheck(withScriptExecuted, tracker)
+        sparkSession.sessionState.analyzer.executeAndCheck(sqlScriptExecuted, tracker)
       }
       tracker.setAnalyzed(plan)
       plan
     } catch {
       case NonFatal(e) =>
-        tracker.setAnalysisFailed(withScriptExecuted)
+        tracker.setAnalysisFailed(sqlScriptExecuted)
         throw e
     }
   }
@@ -699,4 +724,17 @@ object QueryExecution {
       DoNotCleanup
     }
   }
+
+  /**
+   * Determines whether the given unresolved plan is a SQL script.
+   * @param plan Logical plan to check.
+   * @return True if the plan is a SQL script, False otherwise.
+   */
+  def isUnresolvedPlanSqlScript(plan: LogicalPlan): Boolean = {
+    plan match {
+      case _: CompoundBody => true
+      case NameParameterizedQuery(_: CompoundBody, _, _) => true
+      case _ => false
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
