This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1368983b58f1 [SPARK-52134][SQL][FOLLOW UP] SQL Script execution code - refactor follow up
1368983b58f1 is described below
commit 1368983b58f1f49e7e420405a4f399471dfb1512
Author: David Milicevic <[email protected]>
AuthorDate: Thu Oct 30 10:02:37 2025 +0800
[SPARK-52134][SQL][FOLLOW UP] SQL Script execution code - refactor follow up
### What changes were proposed in this pull request?
The original change was introduced in https://github.com/apache/spark/pull/50895. This follow-up:
1. Restructures the code more cleanly.
2. Adds more meaningful Spark Connect tests and moves them to a more appropriate suite (see the sketch below).
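For illustration only, the behavior the relocated Spark Connect tests exercise looks roughly like the following sketch (not part of the patch; it assumes a running SparkSession named `spark`, and that SQL scripting is enabled where a deployment still requires `spark.sql.scripting.enabled=true`, as the removed server-side test used to set):

    // Sketch only: a SQL script returns the result of its last
    // result-producing statement, mirroring the new "SQL Script result" test.
    val df = spark.sql(
      """BEGIN
        |  IF 1 = 1 THEN
        |    SELECT 1;
        |  ELSE
        |    SELECT 3;
        |  END IF;
        |END
        |""".stripMargin)
    df.show() // expected: a single row with value 1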
### Why are the changes needed?
Cleaner and more consistent code. No functional changes.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
There shouldn't be any functional changes. Existing tests are sufficient to confirm that the change works as expected.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52786 from davidm-db/davidm-db/execute_sql_script_refactor_v2.
Authored-by: David Milicevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/connect/ClientE2ETestSuite.scala | 56 ++++++++++++++++++++++
.../sql/connect/planner/SparkConnectPlanner.scala | 4 +-
.../spark/sql/connect/SparkConnectServerTest.scala | 42 ----------------
.../service/SparkConnectServiceE2ESuite.scala | 21 --------
.../spark/sql/execution/QueryExecution.scala | 56 ++++++++++++++++++----
5 files changed, 105 insertions(+), 74 deletions(-)
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index db165c03ad35..8c336b6fa6d5 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1792,6 +1792,62 @@ class ClientE2ETestSuite
assert(result.length === 1)
assert(result(0).getAs[Array[Integer]]("arr_col") === Array(1, null))
}
+
+ // SQL Scripting tests
+ test("SQL Script result") {
+ val df = spark.sql("""BEGIN
+ | IF 1=1 THEN
+ | SELECT 1;
+ | ELSE
+ | SELECT 3;
+ | END IF;
+ |END
+ |""".stripMargin)
+ checkAnswer(df, Seq(Row(1)))
+ }
+
+ test("SQL Script schema") {
+ withTable("script_tbl") {
+ val df = spark.sql("""BEGIN
+ | CREATE TABLE script_tbl (a INT, b STRING);
+ | INSERT INTO script_tbl VALUES (1, 'Hello'), (2, 'World');
+ | SELECT * FROM script_tbl;
+ |END
+ |""".stripMargin)
+ assert(
+ df.schema == StructType(
+ StructField("a", IntegerType, nullable = true)
+ :: StructField("b", StringType, nullable = true)
+ :: Nil))
+ }
+ }
+
+ test("SQL Script empty result") {
+ withTable("script_tbl") {
+ val df = spark.sql("""BEGIN
+ | CREATE TABLE script_tbl (a INT, b STRING);
+ | SELECT * FROM script_tbl;
+ |END
+ |""".stripMargin)
+ assert(
+ df.schema == StructType(
+ StructField("a", IntegerType, nullable = true)
+ :: StructField("b", StringType, nullable = true)
+ :: Nil))
+ checkAnswer(df, Seq.empty)
+ }
+ }
+
+ test("SQL Script no result") {
+ withTable("script_tbl") {
+ val df = spark.sql("""BEGIN
+ | CREATE TABLE script_tbl (a INT, b STRING);
+ |END
+ |""".stripMargin)
+ assert(df.schema == StructType(Nil))
+ checkAnswer(df, Seq.empty)
+ }
+ }
}
private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ebcf462b84ce..41efe8db842f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -52,7 +52,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.parser.{NamedParameterContext, ParameterContext, ParseException, ParserUtils, PositionalParameterContext}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, CompoundBody, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Un [...]
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, Assignment, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeleteAction, DeserializeToObject, Except, FlatMapGroupsWithState, InsertAction, InsertStarAction, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, MergeAction, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TimeModes, TransformWithState, TypedFilter, Union, Unpivot, Unresol [...]
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -2975,7 +2975,7 @@ class SparkConnectPlanner(
// Check if command or SQL Script has been executed.
val isCommand =
df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
- val isSqlScript = df.queryExecution.logical.isInstanceOf[CompoundBody]
+ val isSqlScript = df.queryExecution.isSqlScript
val rows = df.logicalPlan match {
case lr: LocalRelation => lr.data
case cr: CommandResult => cr.rows
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 91e728d73e13..1b2b7ab42029 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -16,19 +16,16 @@
*/
package org.apache.spark.sql.connect
-import java.io.ByteArrayInputStream
import java.util.{TimeZone, UUID}
import scala.reflect.runtime.universe.TypeTag
import org.apache.arrow.memory.RootAllocator
-import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.scalatest.concurrent.{Eventually, TimeLimits}
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._
import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
@@ -323,43 +320,4 @@ trait SparkConnectServerTest extends SharedSparkSession {
val plan = buildPlan(query)
runQuery(plan, queryTimeout, iterSleep)
}
-
- protected def checkSqlCommandResponse(
- result: ExecutePlanResponse.SqlCommandResult,
- expected: Seq[Seq[Any]]): Unit = {
- // Extract the serialized Arrow data as a byte array.
- val dataBytes = result.getRelation.getLocalRelation.getData.toByteArray
-
- // Create an ArrowStreamReader to deserialize the data.
- val allocator = new RootAllocator(Long.MaxValue)
- val inputStream = new ByteArrayInputStream(dataBytes)
- val reader = new ArrowStreamReader(inputStream, allocator)
-
- try {
- // Read the schema and data.
- val root = reader.getVectorSchemaRoot
- // Load the first batch of data.
- reader.loadNextBatch()
-
- // Get dimensions.
- val rowCount = root.getRowCount
- val colCount = root.getFieldVectors.size
- assert(rowCount == expected.length, "Row count mismatch")
- assert(colCount == expected.head.length, "Column count mismatch")
-
- // Compare to expected.
- for (i <- 0 until rowCount) {
- for (j <- 0 until colCount) {
- val col = root.getFieldVectors.get(j)
- val value = col.getObject(i)
- print(value)
- assert(value == expected(i)(j), s"Value mismatch at ($i, $j)")
- }
- }
- } finally {
- // Clean up resources.
- reader.close()
- allocator.close()
- }
- }
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index e3ba35073f41..0e18ff711c4c 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -33,27 +33,6 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
// were all already in the buffer.
val BIG_ENOUGH_QUERY = "select * from range(1000000)"
- test("SQL Script over Spark Connect.") {
- val sessionId = UUID.randomUUID.toString()
- val userId = "ScriptUser"
- val sqlScriptText =
- """BEGIN
- |IF 1 = 1 THEN
- | SELECT 1;
- |ELSE
- | SELECT 2;
- |END IF;
- |END
- """.stripMargin
- withClient(sessionId = sessionId, userId = userId) { client =>
- // this will create the session, and then ReleaseSession at the end of withClient.
- val enableSqlScripting = client.execute(buildPlan("SET spark.sql.scripting.enabled=true"))
- enableSqlScripting.hasNext // trigger execution
- val query = client.execute(buildSqlCommandPlan(sqlScriptText))
- checkSqlCommandResponse(query.next().getSqlCommandResult, Seq(Seq(1)))
- }
- }
-
test("Execute is sent eagerly to the server upon iterator creation") {
// This behavior changed with grpc upgrade from 1.56.0 to 1.59.0.
// Testing to be aware of future changes.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b46172001a87..27d6eec46b69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -72,6 +72,12 @@ class QueryExecution(
// TODO: Move the planner an optimizer into here from SessionState.
protected def planner = sparkSession.sessionState.planner
+ /**
+ * Check whether the query represented by this QueryExecution is a SQL script.
+ * @return True if the query is a SQL script, False otherwise.
+ */
+ def isSqlScript: Boolean = QueryExecution.isUnresolvedPlanSqlScript(logical)
+
lazy val isLazyAnalysis: Boolean = {
// Only check the main query as subquery expression can be resolved now with the main query.
logical.exists(_.expressions.exists(_.exists(_.isInstanceOf[LazyExpression])))
@@ -95,27 +101,46 @@ class QueryExecution(
}
}
- private val lazyAnalyzed = LazyTry {
- val withScriptExecuted = logical match {
- // Execute the SQL script. Script doesn't need to go through the analyzer as Spark will run
- // each statement as individual query.
+ /**
+ * Execute the SQL script if the logical plan is a SQL script.
+ * There are multiple cases, and they are originating from:
+ * - SparkSession.sql() - Spark and Spark Connect case
+ * - QueryRuntimePredictionUtils.getParsedPlanWithTracking() - DBSQL case
+ */
+ private val lazySqlScriptExecuted = LazyTry {
+ logical match {
case NameParameterizedQuery(compoundBody: CompoundBody, argNames, argValues) =>
- val args = argNames.zip(argValues).toMap
- SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody, args)
+ SqlScriptingExecution.executeSqlScript(
+ session = sparkSession,
+ script = compoundBody,
+ args = argNames.zip(argValues).toMap)
case compoundBody: CompoundBody =>
- SqlScriptingExecution.executeSqlScript(sparkSession, compoundBody)
+ SqlScriptingExecution.executeSqlScript(
+ session = sparkSession,
+ script = compoundBody)
case _ => logical
}
+ }
+
+ private def sqlScriptExecuted: LogicalPlan = lazySqlScriptExecuted.get
+
+ private def assertSqlScriptExecuted(): Unit = sqlScriptExecuted
+
+ private val lazyAnalyzed = LazyTry {
+ // Execute the SQL script. Script doesn't need to go through the analyzer as Spark
+ // will run each statement as individual query.
+ assertSqlScriptExecuted()
+
try {
val plan = executePhase(QueryPlanningTracker.ANALYSIS) {
- // We can't clone `logical` here, which will reset the `_analyzed` flag.
- sparkSession.sessionState.analyzer.executeAndCheck(withScriptExecuted, tracker)
+ sparkSession.sessionState.analyzer.executeAndCheck(sqlScriptExecuted, tracker)
}
tracker.setAnalyzed(plan)
plan
} catch {
case NonFatal(e) =>
- tracker.setAnalysisFailed(withScriptExecuted)
+ tracker.setAnalysisFailed(sqlScriptExecuted)
throw e
}
}
@@ -699,4 +724,17 @@ object QueryExecution {
DoNotCleanup
}
}
+
+ /**
+ * Determines whether the given unresolved plan is a SQL script.
+ * @param plan Logical plan to check.
+ * @return True if the plan is a SQL script, False otherwise.
+ */
+ def isUnresolvedPlanSqlScript(plan: LogicalPlan): Boolean = {
+ plan match {
+ case _: CompoundBody => true
+ case NameParameterizedQuery(_: CompoundBody, _, _) => true
+ case _ => false
+ }
+ }
}
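A rough usage sketch (not part of the patch): after this change, callers can detect a SQL script through the new helpers instead of matching on CompoundBody themselves, mirroring what SparkConnectPlanner now does via df.queryExecution.isSqlScript. The object name below is hypothetical, and whether these helpers are intended as stable external API is not stated here.

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.QueryExecution

    object SqlScriptDetectionSketch {
      // On an unresolved plan, e.g. straight out of the parser:
      def looksLikeScript(parsed: LogicalPlan): Boolean =
        QueryExecution.isUnresolvedPlanSqlScript(parsed)

      // Given a QueryExecution (the Connect planner reaches it via df.queryExecution):
      def isScript(qe: QueryExecution): Boolean = qe.isSqlScript
    }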
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]