This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 30cba12a51fc [SPARK-48344][SQL] Add SQL Scripting Execution Framework
30cba12a51fc is described below
commit 30cba12a51fcfda7fe42089e077ae53504be946e
Author: Milan Dankovic <[email protected]>
AuthorDate: Wed Nov 27 12:40:50 2024 +0800
[SPARK-48344][SQL] Add SQL Scripting Execution Framework
### What changes were proposed in this pull request?
This PR is the second in a series of refactorings of SQL Scripting, preparing it for the addition of the **Execution Framework**:
- Introducing `SqlScriptingExecution`, a new iterator that collects and returns only result statements.
- Enabling execution of SQL scripts through the `sql()` API (see the sketch below).
- Enabling named parameters to be used with SQL Scripting.
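For illustration, a minimal usage sketch (the script text, parameter name, and values are hypothetical, and SQL scripting is assumed to be enabled via `SQLConf.SQL_SCRIPTING_ENABLED`):
```scala
// Hypothetical sketch: run a SQL script through the sql() API with named
// parameters. Positional parameters now fail with
// UNSUPPORTED_FEATURE.SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS.
val script =
  """
    |BEGIN
    |  IF :threshold > 10 THEN
    |    SELECT 'big';
    |  ELSE
    |    SELECT 'small';
    |  END IF;
    |END""".stripMargin

// Only the result of the last executed statement is returned.
val df = spark.sql(script, Map("threshold" -> 5))
df.show() // single row: "small"
```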
### Why are the changes needed?
These changes are needed to enable execution of SQL scripts.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New `SqlScriptingExecutionSuite` to test the behavior of the newly added component (sketched below).
New `SqlScriptingE2eSuite` to test end-to-end behavior through the `sql()` API.
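As a rough sketch, the execution suite drives the new iterator directly, condensed from the `runSqlScript` helper in the diff below (`sqlText` stands in for a script string):
```scala
// Parse the script into a CompoundBody, run it through SqlScriptingExecution,
// and collect the rows of each result statement. Non-result statements
// (e.g. DDL/DML) are executed but not returned.
val compoundBody = spark.sessionState.sqlParser
  .parsePlan(sqlText).asInstanceOf[CompoundBody]
val sse = new SqlScriptingExecution(compoundBody, spark, Map.empty)
val results: Seq[Array[Row]] = sse.map(_.collect()).toList
```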
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48950 from miland-db/milan-dankovic_data/refactor-execution-2.
Authored-by: Milan Dankovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 5 +
.../spark/sql/errors/SqlScriptingErrors.scala | 8 +
.../scala/org/apache/spark/sql/SparkSession.scala | 92 +++-
.../sql/scripting/SqlScriptingExecution.scala | 92 ++++
.../sql/scripting/SqlScriptingExecutionNode.scala | 30 +-
.../sql/scripting/SqlScriptingInterpreter.scala | 67 ++-
.../spark/sql/scripting/SqlScriptingE2eSuite.scala | 188 +++++++
.../scripting/SqlScriptingExecutionNodeSuite.scala | 2 +
...uite.scala => SqlScriptingExecutionSuite.scala} | 597 ++-------------------
.../scripting/SqlScriptingInterpreterSuite.scala | 9 +-
10 files changed, 507 insertions(+), 583 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 94513cca1023..3c494704fd71 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5381,6 +5381,11 @@
"SQL Scripting is under development and not all features are
supported. SQL Scripting enables users to write procedural SQL including
control flow and error handling. To enable existing features set
<sqlScriptingEnabled> to `true`."
]
},
+ "SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS" : {
+ "message" : [
+ "Positional parameters are not supported with SQL Scripting."
+ ]
+ },
"STATE_STORE_MULTIPLE_COLUMN_FAMILIES" : {
"message" : [
"Creating multiple column families with <stateStoreProvider> is not
supported."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
index f1c07200d503..2a4b8fde6989 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
@@ -103,6 +103,14 @@ private[sql] object SqlScriptingErrors {
messageParameters = Map("invalidStatement" -> toSQLStmt(stmt)))
}
+ def positionalParametersAreNotSupportedWithSqlScripting(): Throwable = {
+ new SqlScriptingException(
+ origin = null,
+ errorClass = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS",
+ cause = null,
+ messageParameters = Map.empty)
+ }
+
def labelDoesNotExist(
origin: Origin,
labelName: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 8cf30fb39f31..dbe4543c3310 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -44,17 +44,19 @@ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParame
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.ExternalCommandRunner
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, SqlScriptingErrors}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ExternalCommandExecutor
import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal._
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.scripting.SqlScriptingExecution
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -431,6 +433,42 @@ class SparkSession private(
| Everything else |
* ----------------- */
+ /**
+ * Executes the given script and returns the result of the last statement.
+ * If the script contains no queries, an empty `DataFrame` is returned.
+ *
+ * @param script A SQL script to execute.
+ * @param args A map of parameter names to SQL literal expressions.
+ *
+ * @return The result as a `DataFrame`.
+ */
+ private def executeSqlScript(
+ script: CompoundBody,
+ args: Map[String, Expression] = Map.empty): DataFrame = {
+ val sse = new SqlScriptingExecution(script, this, args)
+ var result: Option[Seq[Row]] = None
+
+ while (sse.hasNext) {
+ sse.withErrorHandling {
+ val df = sse.next()
+ if (sse.hasNext) {
+ df.write.format("noop").mode("overwrite").save()
+ } else {
+ // Collect results from the last DataFrame.
+ result = Some(df.collect().toSeq)
+ }
+ }
+ }
+
+ if (result.isEmpty) {
+ emptyDataFrame
+ } else {
+ val attributes = DataTypeUtils.toAttributes(result.get.head.schema)
+ Dataset.ofRows(
+ self, LocalRelation.fromExternalRows(attributes, result.get))
+ }
+ }
+
/**
* Executes a SQL query substituting positional parameters by the given arguments,
* returning the result as a `DataFrame`.
@@ -450,13 +488,30 @@ class SparkSession private(
withActive {
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
- if (args.nonEmpty) {
PosParameterizedQuery(parsedPlan, args.map(lit(_).expr).toImmutableArraySeq)
- } else {
- parsedPlan
+ parsedPlan match {
+ case compoundBody: CompoundBody =>
+ if (args.nonEmpty) {
+ // Positional parameters are not supported for SQL scripting.
+            throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
+ }
+ compoundBody
+ case logicalPlan: LogicalPlan =>
+ if (args.nonEmpty) {
+            PosParameterizedQuery(logicalPlan, args.map(lit(_).expr).toImmutableArraySeq)
+ } else {
+ logicalPlan
+ }
}
}
- Dataset.ofRows(self, plan, tracker)
+
+ plan match {
+ case compoundBody: CompoundBody =>
+ // Execute the SQL script.
+ executeSqlScript(compoundBody)
+ case logicalPlan: LogicalPlan =>
+ // Execute the standalone SQL statement.
+ Dataset.ofRows(self, plan, tracker)
+ }
}
/** @inheritdoc */
@@ -487,13 +542,26 @@ class SparkSession private(
withActive {
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
val parsedPlan = sessionState.sqlParser.parsePlan(sqlText)
- if (args.nonEmpty) {
-      NameParameterizedQuery(parsedPlan, args.transform((_, v) => lit(v).expr))
- } else {
- parsedPlan
+ parsedPlan match {
+ case compoundBody: CompoundBody =>
+ compoundBody
+ case logicalPlan: LogicalPlan =>
+ if (args.nonEmpty) {
+            NameParameterizedQuery(logicalPlan, args.transform((_, v) => lit(v).expr))
+ } else {
+ logicalPlan
+ }
}
}
- Dataset.ofRows(self, plan, tracker)
+
+ plan match {
+ case compoundBody: CompoundBody =>
+ // Execute the SQL script.
+ executeSqlScript(compoundBody, args.transform((_, v) => lit(v).expr))
+ case logicalPlan: LogicalPlan =>
+ // Execute the standalone SQL statement.
+ Dataset.ofRows(self, plan, tracker)
+ }
}
/** @inheritdoc */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
new file mode 100644
index 000000000000..59252f622918
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.scripting
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}
+
+/**
+ * SQL scripting executor - executes script and returns result statements.
+ * This supports returning multiple result statements from a single script.
+ *
+ * @param sqlScript CompoundBody which needs to be executed.
+ * @param session Spark session that SQL script is executed within.
+ * @param args A map of parameter names to SQL literal expressions.
+ */
+class SqlScriptingExecution(
+ sqlScript: CompoundBody,
+ session: SparkSession,
+ args: Map[String, Expression]) extends Iterator[DataFrame] {
+
+ // Build the execution plan for the script.
+ private val executionPlan: Iterator[CompoundStatementExec] =
+ SqlScriptingInterpreter(session).buildExecutionPlan(sqlScript, args)
+
+ private var current = getNextResult
+
+ override def hasNext: Boolean = current.isDefined
+
+ override def next(): DataFrame = {
+    if (!hasNext) throw SparkException.internalError("No more elements to iterate through.")
+ val nextDataFrame = current.get
+ current = getNextResult
+ nextDataFrame
+ }
+
+  /** Helper method to iterate through statements until next result statement is encountered. */
+ private def getNextResult: Option[DataFrame] = {
+
+ def getNextStatement: Option[CompoundStatementExec] =
+ if (executionPlan.hasNext) Some(executionPlan.next()) else None
+
+ var currentStatement = getNextStatement
+ // While we don't have a result statement, execute the statements.
+ while (currentStatement.isDefined) {
+ currentStatement match {
+ case Some(stmt: SingleStatementExec) if !stmt.isExecuted =>
+ withErrorHandling {
+ val df = stmt.buildDataFrame(session)
+ df.logicalPlan match {
+ case _: CommandResult => // pass
+            case _ => return Some(df) // If the statement is a result, return it to the caller.
+ }
+ }
+ case _ => // pass
+ }
+ currentStatement = getNextStatement
+ }
+ None
+ }
+
+ private def handleException(e: Throwable): Unit = {
+ // Rethrow the exception.
+ // TODO: SPARK-48353 Add error handling for SQL scripts
+ throw e
+ }
+
+ def withErrorHandling(f: => Unit): Unit = {
+ try {
+ f
+ } catch {
+ case e: Throwable =>
+ handleException(e)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 9129fc6ab00f..94284ec514f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.scripting
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.NameParameterizedQuery
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.errors.SqlScriptingErrors
@@ -77,7 +79,7 @@ trait NonLeafStatementExec extends CompoundStatementExec {
// DataFrame evaluates to True if it is single row, single column
// of boolean type with value True.
- val df = Dataset.ofRows(session, statement.parsedPlan)
+ val df = statement.buildDataFrame(session)
df.schema.fields match {
case Array(field) if field.dataType == BooleanType =>
df.limit(2).collect() match {
@@ -105,6 +107,8 @@ trait NonLeafStatementExec extends CompoundStatementExec {
* Logical plan of the parsed statement.
* @param origin
* Origin descriptor for the statement.
+ * @param args
+ * A map of parameter names to SQL literal expressions.
* @param isInternal
* Whether the statement originates from the SQL script or it is created during the
* interpretation. Example: DropVariable statements are automatically created at the end of each
@@ -113,6 +117,7 @@ trait NonLeafStatementExec extends CompoundStatementExec {
class SingleStatementExec(
var parsedPlan: LogicalPlan,
override val origin: Origin,
+ val args: Map[String, Expression],
override val isInternal: Boolean)
extends LeafStatementExec with WithOrigin {
@@ -122,6 +127,17 @@ class SingleStatementExec(
*/
var isExecuted = false
+ /**
+ * Plan with named parameters.
+ */
+ private lazy val preparedPlan: LogicalPlan = {
+ if (args.nonEmpty) {
+ NameParameterizedQuery(parsedPlan, args)
+ } else {
+ parsedPlan
+ }
+ }
+
/**
* Get the SQL query text corresponding to this statement.
* @return
@@ -132,6 +148,16 @@ class SingleStatementExec(
origin.sqlText.get.substring(origin.startIndex.get, origin.stopIndex.get + 1)
}
+ /**
+ * Builds a DataFrame from the parsedPlan of this SingleStatementExec
+ * @param session The SparkSession on which the parsedPlan is built.
+ * @return
+ * The DataFrame.
+ */
+ def buildDataFrame(session: SparkSession): DataFrame = {
+ Dataset.ofRows(session, preparedPlan)
+ }
+
override def reset(): Unit = isExecuted = false
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 1be75cb61c8b..387ae36b881f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -19,13 +19,17 @@ package org.apache.spark.sql.scripting
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.trees.Origin
/**
* SQL scripting interpreter - builds SQL script execution plan.
+ *
+ * @param session
+ * Spark session that SQL script is executed within.
*/
-case class SqlScriptingInterpreter() {
+case class SqlScriptingInterpreter(session: SparkSession) {
/**
* Build execution plan and return statements that need to be executed,
@@ -33,15 +37,16 @@ case class SqlScriptingInterpreter() {
*
* @param compound
* CompoundBody for which to build the plan.
- * @param session
- * Spark session that SQL script is executed within.
+ * @param args
+ * A map of parameter names to SQL literal expressions.
* @return
* Iterator through collection of statements to be executed.
*/
def buildExecutionPlan(
compound: CompoundBody,
- session: SparkSession): Iterator[CompoundStatementExec] = {
-    transformTreeIntoExecutable(compound, session).asInstanceOf[CompoundBodyExec].getTreeIterator
+ args: Map[String, Expression]): Iterator[CompoundStatementExec] = {
+ transformTreeIntoExecutable(compound, args)
+ .asInstanceOf[CompoundBodyExec].getTreeIterator
}
/**
@@ -62,13 +67,14 @@ case class SqlScriptingInterpreter() {
*
* @param node
* Root node of the parsed tree.
- * @param session
- * Spark session that SQL script is executed within.
+ * @param args
+ * A map of parameter names to SQL literal expressions.
* @return
* Executable statement.
*/
private def transformTreeIntoExecutable(
-      node: CompoundPlanStatement, session: SparkSession): CompoundStatementExec =
+ node: CompoundPlanStatement,
+ args: Map[String, Expression]): CompoundStatementExec =
node match {
case CompoundBody(collection, label) =>
// TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing.
@@ -78,49 +84,65 @@ case class SqlScriptingInterpreter() {
}
val dropVariables = variables
.map(varName => DropVariable(varName, ifExists = true))
- .map(new SingleStatementExec(_, Origin(), isInternal = true))
+ .map(new SingleStatementExec(_, Origin(), args, isInternal = true))
.reverse
new CompoundBodyExec(
-        collection.map(st => transformTreeIntoExecutable(st, session)) ++ dropVariables,
+        collection.map(st => transformTreeIntoExecutable(st, args)) ++ dropVariables,
label)
case IfElseStatement(conditions, conditionalBodies, elseBody) =>
val conditionsExec = conditions.map(condition =>
-        new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false))
+ new SingleStatementExec(
+ condition.parsedPlan,
+ condition.origin,
+ args,
+ isInternal = false))
val conditionalBodiesExec = conditionalBodies.map(body =>
-        transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
+        transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec])
val unconditionalBodiesExec = elseBody.map(body =>
-        transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
+        transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec])
new IfElseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
case CaseStatement(conditions, conditionalBodies, elseBody) =>
val conditionsExec = conditions.map(condition =>
-          // todo: what to put here for isInternal, in case of simple case statement
-          new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false))
+ new SingleStatementExec(
+ condition.parsedPlan,
+ condition.origin,
+ args,
+ isInternal = false))
val conditionalBodiesExec = conditionalBodies.map(body =>
-        transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
+        transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec])
val unconditionalBodiesExec = elseBody.map(body =>
-        transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
+        transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec])
new CaseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
case WhileStatement(condition, body, label) =>
val conditionExec =
-        new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)
+ new SingleStatementExec(
+ condition.parsedPlan,
+ condition.origin,
+ args,
+ isInternal = false)
val bodyExec =
-        transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]
+        transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec]
new WhileStatementExec(conditionExec, bodyExec, label, session)
case RepeatStatement(condition, body, label) =>
val conditionExec =
-        new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)
+ new SingleStatementExec(
+ condition.parsedPlan,
+ condition.origin,
+ args,
+ isInternal = false)
val bodyExec =
-        transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]
+        transformTreeIntoExecutable(body, args).asInstanceOf[CompoundBodyExec]
new RepeatStatementExec(conditionExec, bodyExec, label, session)
case LoopStatement(body, label) =>
-      val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]
+ val bodyExec = transformTreeIntoExecutable(body, args)
+ .asInstanceOf[CompoundBodyExec]
new LoopStatementExec(bodyExec, label)
case leaveStatement: LeaveStatement =>
@@ -133,6 +155,7 @@ case class SqlScriptingInterpreter() {
new SingleStatementExec(
sparkStatement.parsedPlan,
sparkStatement.origin,
+ args,
isInternal = false)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
new file mode 100644
index 000000000000..afcdfd343e33
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingE2eSuite.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.scripting
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
+import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLConf
+import org.apache.spark.sql.exceptions.SqlScriptingException
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+
+/**
+ * End-to-end tests for SQL Scripting.
+ * This suite is not intended to heavily test the SQL scripting (parser & interpreter) logic.
+ * It is rather focused on testing the sql() API - whether it can handle SQL scripts correctly,
+ * results are returned in expected manner, config flags are applied properly, etc.
+ * For full functionality tests, see SqlScriptingParserSuite and SqlScriptingInterpreterSuite.
+ */
+class SqlScriptingE2eSuite extends QueryTest with SharedSparkSession {
+ // Helpers
+  private def verifySqlScriptResult(sqlText: String, expected: Seq[Row]): Unit = {
+ val df = spark.sql(sqlText)
+ checkAnswer(df, expected)
+ }
+
+ private def verifySqlScriptResultWithNamedParams(
+ sqlText: String,
+ expected: Seq[Row],
+ args: Map[String, Any]): Unit = {
+ val df = spark.sql(sqlText, args)
+ checkAnswer(df, expected)
+ }
+
+ // Tests setup
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf.set(SQLConf.SQL_SCRIPTING_ENABLED.key, "true")
+ }
+
+ // Tests
+ test("SQL Scripting not enabled") {
+ withSQLConf(SQLConf.SQL_SCRIPTING_ENABLED.key -> "false") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | SELECT 1;
+ |END""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ spark.sql(sqlScriptText).asInstanceOf[CompoundBody]
+ },
+ condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING",
+ parameters = Map("sqlScriptingEnabled" ->
toSQLConf(SQLConf.SQL_SCRIPTING_ENABLED.key)))
+ }
+ }
+
+ test("single select") {
+ val sqlText = "SELECT 1;"
+ verifySqlScriptResult(sqlText, Seq(Row(1)))
+ }
+
+ test("multiple selects") {
+ val sqlText =
+ """
+ |BEGIN
+ | SELECT 1;
+ | SELECT 2;
+ |END""".stripMargin
+ verifySqlScriptResult(sqlText, Seq(Row(2)))
+ }
+
+ test("multi statement - simple") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet;
+ | INSERT INTO t VALUES (1, 'a', 1.0);
+ | SELECT a FROM t;
+ |END
+ |""".stripMargin
+ verifySqlScriptResult(sqlScript, Seq(Row(1)))
+ }
+ }
+
+ test("script without result statement") {
+ val sqlScript =
+ """
+ |BEGIN
+ | DECLARE x INT;
+ | SET x = 1;
+ | DROP TEMPORARY VARIABLE x;
+ |END
+ |""".stripMargin
+ verifySqlScriptResult(sqlScript, Seq.empty)
+ }
+
+ test("empty script") {
+ val sqlScript =
+ """
+ |BEGIN
+ |END
+ |""".stripMargin
+ verifySqlScriptResult(sqlScript, Seq.empty)
+ }
+
+ test("named params") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | SELECT 1;
+ | IF :param_1 > 10 THEN
+ | SELECT :param_2;
+ | ELSE
+ | SELECT :param_3;
+ | END IF;
+ |END""".stripMargin
+ // Define a map with SQL parameters
+ val args: Map[String, Any] = Map(
+ "param_1" -> 5,
+ "param_2" -> "greater",
+ "param_3" -> "smaller"
+ )
+    verifySqlScriptResultWithNamedParams(sqlScriptText, Seq(Row("smaller")), args)
+ }
+
+ test("positional params") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | SELECT 1;
+ | IF ? > 10 THEN
+ | SELECT ?;
+ | ELSE
+ | SELECT ?;
+ | END IF;
+ |END""".stripMargin
+ // Define an array with SQL parameters in the correct order.
+ val args: Array[Any] = Array(5, "greater", "smaller")
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ spark.sql(sqlScriptText, args).asInstanceOf[CompoundBody]
+ },
+      condition = "UNSUPPORTED_FEATURE.SQL_SCRIPTING_WITH_POSITIONAL_PARAMETERS",
+ parameters = Map.empty)
+ }
+
+ test("named params with positional params - should fail") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | SELECT ?;
+ | IF :param > 10 THEN
+ | SELECT 1;
+ | ELSE
+ | SELECT 2;
+ | END IF;
+ |END""".stripMargin
+ // Define a map with SQL parameters.
+ val args: Map[String, Any] = Map("param" -> 5)
+ checkError(
+ exception = intercept[AnalysisException] {
+ spark.sql(sqlScriptText, args).asInstanceOf[CompoundBody]
+ },
+ condition = "UNBOUND_SQL_PARAMETER",
+ parameters = Map("name" -> "_16"),
+ context = ExpectedContext(
+ fragment = "?",
+ start = 16,
+ stop = 16))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index baad5702f4f2..4874ea3d2795 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -39,6 +39,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
extends SingleStatementExec(
parsedPlan = Project(Seq(Alias(Literal(condVal), description)()), OneRowRelation()),
Origin(startIndex = Some(0), stopIndex = Some(description.length)),
+ Map.empty,
isInternal = false)
case class DummyLogicalPlan() extends LeafNode {
@@ -50,6 +51,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
extends SingleStatementExec(
parsedPlan = DummyLogicalPlan(),
Origin(startIndex = Some(0), stopIndex = Some(description.length)),
+ Map.empty,
isInternal = false)
class LoopBooleanConditionEvaluator(condition: TestLoopCondition) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
similarity index 60%
copy from sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
copy to sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index b0b844d2b52c..bbeae942f9fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.scripting
-import org.apache.spark.{SparkConf, SparkException, SparkNumberFormatException}
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row}
-import org.apache.spark.sql.catalyst.QueryPlanningTracker
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
-import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -31,7 +30,7 @@ import org.apache.spark.sql.test.SharedSparkSession
* Output from the interpreter (iterator over executable statements) is then checked - statements
* are executed and output DataFrames are compared with expected outputs.
*/
-class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
+class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
// Tests setup
override protected def sparkConf: SparkConf = {
@@ -39,25 +38,21 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
}
// Helpers
- private def runSqlScript(sqlText: String): Array[DataFrame] = {
- val interpreter = SqlScriptingInterpreter()
+ private def runSqlScript(
+ sqlText: String,
+ args: Map[String, Expression] = Map.empty): Seq[Array[Row]] = {
val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
- val executionPlan = interpreter.buildExecutionPlan(compoundBody, spark)
- executionPlan.flatMap {
- case statement: SingleStatementExec =>
- if (statement.isExecuted) {
- None
- } else {
-          Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker))
- }
- case _ => None
- }.toArray
+ val sse = new SqlScriptingExecution(compoundBody, spark, args)
+ sse.map { df => df.collect() }.toList
}
private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = {
val result = runSqlScript(sqlText)
assert(result.length == expected.length)
-    result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) }
+ result.zip(expected).foreach {
+ case (actualAnswer, expectedAnswer) =>
+ assert(actualAnswer.sameElements(expectedAnswer))
+ }
}
// Tests
@@ -73,9 +68,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|END
|""".stripMargin
val expected = Seq(
- Seq.empty[Row], // create table
- Seq.empty[Row], // insert
- Seq.empty[Row], // select with filter
+ Seq.empty[Row], // select
Seq(Row(1)) // select
)
verifySqlScriptResult(sqlScript, expected)
@@ -97,12 +90,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|FROM t;
|END
|""".stripMargin
- val expected = Seq(
- Seq.empty[Row], // create table
- Seq.empty[Row], // insert #1
- Seq.empty[Row], // insert #2
- Seq(Row(false)) // select
- )
+ val expected = Seq(Seq(Row(false)))
verifySqlScriptResult(sqlScript, expected)
}
}
@@ -116,12 +104,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|SELECT var;
|END
|""".stripMargin
- val expected = Seq(
- Seq.empty[Row], // declare var
- Seq.empty[Row], // set var
- Seq(Row(2)), // select
- Seq.empty[Row] // drop var
- )
+ val expected = Seq(Seq(Row(2)))
verifySqlScriptResult(sqlScript, expected)
}
@@ -134,12 +117,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|SELECT var;
|END
|""".stripMargin
- val expected = Seq(
- Seq.empty[Row], // declare var
- Seq.empty[Row], // set var
- Seq(Row(2)), // select
- Seq.empty[Row] // drop var
- )
+ val expected = Seq(Seq(Row(2)))
verifySqlScriptResult(sqlScript, expected)
}
@@ -163,47 +141,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|END
|""".stripMargin
val expected = Seq(
- Seq.empty[Row], // declare var
Seq(Row(1)), // select
- Seq.empty[Row], // drop var
- Seq.empty[Row], // declare var
Seq(Row(2)), // select
- Seq.empty[Row], // drop var
- Seq.empty[Row], // declare var
- Seq.empty[Row], // set var
- Seq(Row(4)), // select
- Seq.empty[Row] // drop var
+ Seq(Row(4)) // select
)
verifySqlScriptResult(sqlScript, expected)
}
- test("session vars - var out of scope") {
- val varName: String = "testVarName"
- val e = intercept[AnalysisException] {
- val sqlScript =
- s"""
- |BEGIN
- | BEGIN
- | DECLARE $varName = 1;
- | SELECT $varName;
- | END;
- | SELECT $varName;
- |END
- |""".stripMargin
- verifySqlScriptResult(sqlScript, Seq.empty)
- }
- checkError(
- exception = e,
- condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION",
- sqlState = "42703",
- parameters = Map("objectName" -> s"`$varName`"),
- context = ExpectedContext(
- fragment = s"$varName",
- start = 79,
- stop = 89)
- )
- }
-
test("session vars - drop var statement") {
val sqlScript =
"""
@@ -214,13 +158,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|DROP TEMPORARY VARIABLE var;
|END
|""".stripMargin
- val expected = Seq(
- Seq.empty[Row], // declare var
- Seq.empty[Row], // set var
- Seq(Row(2)), // select
- Seq.empty[Row], // drop var - explicit
- Seq.empty[Row] // drop var - implicit
- )
+ val expected = Seq(Seq(Row(2)))
verifySqlScriptResult(sqlScript, expected)
}
@@ -266,7 +204,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END IF;
|END
|""".stripMargin
-
val expected = Seq(Seq(Row(42)))
verifySqlScriptResult(commands, expected)
}
@@ -286,7 +223,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END IF;
|END
|""".stripMargin
-
val expected = Seq(Seq(Row(43)))
verifySqlScriptResult(commands, expected)
}
@@ -303,7 +239,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END IF;
|END
|""".stripMargin
-
val expected = Seq(Seq(Row(43)))
verifySqlScriptResult(commands, expected)
}
@@ -323,7 +258,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END IF;
|END
|""".stripMargin
-
val expected = Seq(Seq(Row(44)))
verifySqlScriptResult(commands, expected)
}
@@ -343,8 +277,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END IF;
|END
|""".stripMargin
-
-    val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43)))
+ val expected = Seq(Seq(Row(43)))
verifySqlScriptResult(commands, expected)
}
}
@@ -366,8 +299,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END IF;
|END
|""".stripMargin
-
-    val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43)))
+ val expected = Seq(Seq(Row(43)))
verifySqlScriptResult(commands, expected)
}
}
@@ -457,8 +389,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END CASE;
|END
|""".stripMargin
-
-    val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43)))
+ val expected = Seq(Seq(Row(43)))
verifySqlScriptResult(commands, expected)
}
}
@@ -481,8 +412,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END CASE;
|END
|""".stripMargin
-
-    val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(43)))
+ val expected = Seq(Seq(Row(43)))
verifySqlScriptResult(commands, expected)
}
}
@@ -499,79 +429,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END CASE;
|END
|""".stripMargin
- val expected = Seq()
+ val expected = Seq.empty
verifySqlScriptResult(commands, expected)
}
- test("searched case when evaluates to null") {
- withTable("t") {
- val commands =
- """
- |BEGIN
- | CREATE TABLE t (a BOOLEAN) USING parquet;
- | CASE
- | WHEN (SELECT * FROM t) THEN
- | SELECT 42;
- | END CASE;
- |END
- |""".stripMargin
-
- checkError(
- exception = intercept[SqlScriptingException] (
- runSqlScript(commands)
- ),
- condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
- parameters = Map("invalidStatement" -> "(SELECT * FROM T)")
- )
- }
- }
-
- test("searched case with non boolean condition - constant") {
- val commands =
- """
- |BEGIN
- | CASE
- | WHEN 1 THEN
- | SELECT 42;
- | END CASE;
- |END
- |""".stripMargin
-
- checkError(
- exception = intercept[SqlScriptingException] (
- runSqlScript(commands)
- ),
- condition = "INVALID_BOOLEAN_STATEMENT",
- parameters = Map("invalidStatement" -> "1")
- )
- }
-
- test("searched case with too many rows in subquery condition") {
- withTable("t") {
- val commands =
- """
- |BEGIN
- | CREATE TABLE t (a BOOLEAN) USING parquet;
- | INSERT INTO t VALUES (true);
- | INSERT INTO t VALUES (true);
- | CASE
- | WHEN (SELECT * FROM t) THEN
- | SELECT 1;
- | END CASE;
- |END
- |""".stripMargin
-
- checkError(
- exception = intercept[SparkException] (
- runSqlScript(commands)
- ),
- condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
- parameters = Map.empty,
- context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 124,
stop = 140)
- )
- }
- }
-
test("simple case") {
val commands =
"""
@@ -659,8 +520,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END CASE;
|END
|""".stripMargin
-
-    val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(42)))
+ val expected = Seq(Seq(Row(42)))
verifySqlScriptResult(commands, expected)
}
}
@@ -683,8 +543,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END CASE;
|END
|""".stripMargin
-
-    val expected = Seq(Seq.empty[Row], Seq.empty[Row], Seq.empty[Row], Seq(Row(44)))
+ val expected = Seq(Seq(Row(44)))
verifySqlScriptResult(commands, expected)
}
}
@@ -701,42 +560,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END CASE;
|END
|""".stripMargin
- val expected = Seq()
+ val expected = Seq.empty
verifySqlScriptResult(commands, expected)
}
- test("simple case mismatched types") {
- val commands =
- """
- |BEGIN
- | CASE 1
- | WHEN "one" THEN
- | SELECT 42;
- | END CASE;
- |END
- |""".stripMargin
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
- checkError(
- exception = intercept[SparkNumberFormatException](
- runSqlScript(commands)
- ),
- condition = "CAST_INVALID_INPUT",
- parameters = Map(
- "expression" -> "'one'",
- "sourceType" -> "\"STRING\"",
- "targetType" -> "\"BIGINT\""),
- context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27))
- }
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
- checkError(
- exception = intercept[SqlScriptingException](
- runSqlScript(commands)
- ),
- condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
- parameters = Map("invalidStatement" -> "\"ONE\""))
- }
- }
-
test("simple case compare with null") {
withTable("t") {
val commands =
@@ -751,81 +578,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END CASE;
|END
|""".stripMargin
-
- val expected = Seq(Seq.empty[Row], Seq(Row(43)))
+ val expected = Seq(Seq(Row(43)))
verifySqlScriptResult(commands, expected)
}
}
- test("if's condition must be a boolean statement") {
- withTable("t") {
- val commands =
- """
- |BEGIN
- | IF 1 THEN
- | SELECT 45;
- | END IF;
- |END
- |""".stripMargin
- val exception = intercept[SqlScriptingException] {
- runSqlScript(commands)
- }
- checkError(
- exception = exception,
- condition = "INVALID_BOOLEAN_STATEMENT",
- parameters = Map("invalidStatement" -> "1")
- )
- assert(exception.origin.line.isDefined)
- assert(exception.origin.line.get == 3)
- }
- }
-
- test("if's condition must return a single row data") {
- withTable("t1", "t2") {
- // empty row
- val commands1 =
- """
- |BEGIN
- | CREATE TABLE t1 (a BOOLEAN) USING parquet;
- | IF (SELECT * FROM t1) THEN
- | SELECT 46;
- | END IF;
- |END
- |""".stripMargin
- val exception = intercept[SqlScriptingException] {
- runSqlScript(commands1)
- }
- checkError(
- exception = exception,
- condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
- parameters = Map("invalidStatement" -> "(SELECT * FROM T1)")
- )
- assert(exception.origin.line.isDefined)
- assert(exception.origin.line.get == 4)
-
- // too many rows ( > 1 )
- val commands2 =
- """
- |BEGIN
- | CREATE TABLE t2 (a BOOLEAN) USING parquet;
- | INSERT INTO t2 VALUES (true);
- | INSERT INTO t2 VALUES (true);
- | IF (SELECT * FROM t2) THEN
- | SELECT 46;
- | END IF;
- |END
- |""".stripMargin
- checkError(
- exception = intercept[SparkException] (
- runSqlScript(commands2)
- ),
- condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
- parameters = Map.empty,
- context = ExpectedContext(fragment = "(SELECT * FROM t2)", start =
121, stop = 138)
- )
- }
- }
-
test("while") {
val commands =
"""
@@ -837,16 +594,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END WHILE;
|END
|""".stripMargin
-
val expected = Seq(
- Seq.empty[Row], // declare i
Seq(Row(0)), // select i
- Seq.empty[Row], // set i
Seq(Row(1)), // select i
- Seq.empty[Row], // set i
- Seq(Row(2)), // select i
- Seq.empty[Row], // set i
- Seq.empty[Row] // drop var
+ Seq(Row(2)) // select i
)
verifySqlScriptResult(commands, expected)
}
@@ -862,11 +613,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END WHILE;
|END
|""".stripMargin
-
- val expected = Seq(
- Seq.empty[Row], // declare i
- Seq.empty[Row] // drop i
- )
+ val expected = Seq.empty
verifySqlScriptResult(commands, expected)
}
@@ -886,24 +633,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END WHILE;
|END
|""".stripMargin
-
val expected = Seq(
- Seq.empty[Row], // declare i
- Seq.empty[Row], // declare j
- Seq.empty[Row], // set j to 0
Seq(Row(0, 0)), // select i, j
- Seq.empty[Row], // increase j
Seq(Row(0, 1)), // select i, j
- Seq.empty[Row], // increase j
- Seq.empty[Row], // increase i
- Seq.empty[Row], // set j to 0
Seq(Row(1, 0)), // select i, j
- Seq.empty[Row], // increase j
- Seq(Row(1, 1)), // select i, j
- Seq.empty[Row], // increase j
- Seq.empty[Row], // increase i
- Seq.empty[Row], // drop j
- Seq.empty[Row] // drop i
+ Seq(Row(1, 1)) // select i, j
)
verifySqlScriptResult(commands, expected)
}
@@ -920,13 +654,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|END WHILE;
|END
|""".stripMargin
-
val expected = Seq(
- Seq.empty[Row], // create table
- Seq(Row(42)), // select
- Seq.empty[Row], // insert
Seq(Row(42)), // select
- Seq.empty[Row] // insert
+ Seq(Row(42)) // select
)
verifySqlScriptResult(commands, expected)
}
@@ -945,16 +675,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END REPEAT;
|END
|""".stripMargin
-
val expected = Seq(
- Seq.empty[Row], // declare i
Seq(Row(0)), // select i
- Seq.empty[Row], // set i
Seq(Row(1)), // select i
- Seq.empty[Row], // set i
- Seq(Row(2)), // select i
- Seq.empty[Row], // set i
- Seq.empty[Row] // drop var
+ Seq(Row(2)) // select i
)
verifySqlScriptResult(commands, expected)
}
@@ -973,12 +697,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|END
|""".stripMargin
- val expected = Seq(
- Seq.empty[Row], // declare i
- Seq(Row(3)), // select i
- Seq.empty[Row], // set i
- Seq.empty[Row] // drop i
- )
+ val expected = Seq(Seq(Row(3)))
verifySqlScriptResult(commands, expected)
}
@@ -1002,22 +721,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|""".stripMargin
val expected = Seq(
- Seq.empty[Row], // declare i
- Seq.empty[Row], // declare j
- Seq.empty[Row], // set j to 0
Seq(Row(0, 0)), // select i, j
- Seq.empty[Row], // increase j
Seq(Row(0, 1)), // select i, j
- Seq.empty[Row], // increase j
- Seq.empty[Row], // increase i
- Seq.empty[Row], // set j to 0
Seq(Row(1, 0)), // select i, j
- Seq.empty[Row], // increase j
- Seq(Row(1, 1)), // select i, j
- Seq.empty[Row], // increase j
- Seq.empty[Row], // increase i
- Seq.empty[Row], // drop j
- Seq.empty[Row] // drop i
+ Seq(Row(1, 1)) // select i, j
)
verifySqlScriptResult(commands, expected)
}
@@ -1037,90 +744,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|""".stripMargin
val expected = Seq(
- Seq.empty[Row], // create table
- Seq(Row(42)), // select
- Seq.empty[Row], // insert
Seq(Row(42)), // select
- Seq.empty[Row] // insert
+ Seq(Row(42)) // select
)
verifySqlScriptResult(commands, expected)
}
}
- test("repeat with non boolean condition - constant") {
- val commands =
- """
- |BEGIN
- | DECLARE i = 0;
- | REPEAT
- | SELECT i;
- | SET VAR i = i + 1;
- | UNTIL
- | 1
- | END REPEAT;
- |END
- |""".stripMargin
-
- checkError(
- exception = intercept[SqlScriptingException] (
- runSqlScript(commands)
- ),
- condition = "INVALID_BOOLEAN_STATEMENT",
- parameters = Map("invalidStatement" -> "1")
- )
- }
-
- test("repeat with empty subquery condition") {
- withTable("t") {
- val commands =
- """
- |BEGIN
- | CREATE TABLE t (a BOOLEAN) USING parquet;
- | REPEAT
- | SELECT 1;
- | UNTIL
- | (SELECT * FROM t)
- | END REPEAT;
- |END
- |""".stripMargin
-
- checkError(
- exception = intercept[SqlScriptingException] (
- runSqlScript(commands)
- ),
- condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
- parameters = Map("invalidStatement" -> "(SELECT * FROM T)")
- )
- }
- }
-
- test("repeat with too many rows in subquery condition") {
- withTable("t") {
- val commands =
- """
- |BEGIN
- | CREATE TABLE t (a BOOLEAN) USING parquet;
- | INSERT INTO t VALUES (true);
- | INSERT INTO t VALUES (true);
- | REPEAT
- | SELECT 1;
- | UNTIL
- | (SELECT * FROM t)
- | END REPEAT;
- |END
- |""".stripMargin
-
- checkError(
- exception = intercept[SparkException] (
- runSqlScript(commands)
- ),
- condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
- parameters = Map.empty,
- context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 141,
stop = 157)
- )
- }
- }
-
test("leave compound block") {
val sqlScriptText =
"""
@@ -1131,9 +761,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| SELECT 2;
| END;
|END""".stripMargin
- val expected = Seq(
- Seq(Row(1)) // select
- )
+ val expected = Seq(Seq(Row(1)))
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1146,9 +774,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| LEAVE lbl;
| END WHILE;
|END""".stripMargin
- val expected = Seq(
- Seq(Row(1)) // select
- )
+ val expected = Seq(Seq(Row(1)))
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1162,29 +788,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| UNTIL 1 = 2
| END REPEAT;
|END""".stripMargin
- val expected = Seq(
- Seq(Row(1)) // select 1
- )
+ val expected = Seq(Seq(Row(1)))
verifySqlScriptResult(sqlScriptText, expected)
}
- test("iterate compound block - should fail") {
- val sqlScriptText =
- """
- |BEGIN
- | lbl: BEGIN
- | SELECT 1;
- | ITERATE lbl;
- | END;
- |END""".stripMargin
- checkError(
- exception = intercept[SqlScriptingException] {
- runSqlScript(sqlScriptText)
- },
- condition = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND",
- parameters = Map("labelName" -> "LBL"))
- }
-
test("iterate while loop") {
val sqlScriptText =
"""
@@ -1198,14 +805,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END WHILE;
| SELECT x;
|END""".stripMargin
- val expected = Seq(
- Seq.empty[Row], // declare
- Seq.empty[Row], // set x = 0
- Seq.empty[Row], // set x = 1
- Seq.empty[Row], // set x = 2
- Seq(Row(2)), // select
- Seq.empty[Row] // drop
- )
+ val expected = Seq(Seq(Row(2)))
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1223,51 +823,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END REPEAT;
| SELECT x;
|END""".stripMargin
- val expected = Seq(
- Seq.empty[Row], // declare
- Seq.empty[Row], // set x = 0
- Seq.empty[Row], // set x = 1
- Seq.empty[Row], // set x = 2
- Seq(Row(2)), // select x
- Seq.empty[Row] // drop
- )
+ val expected = Seq(Seq(Row(2)))
verifySqlScriptResult(sqlScriptText, expected)
}
- test("leave with wrong label - should fail") {
- val sqlScriptText =
- """
- |BEGIN
- | lbl: BEGIN
- | SELECT 1;
- | LEAVE randomlbl;
- | END;
- |END""".stripMargin
- checkError(
- exception = intercept[SqlScriptingException] {
- runSqlScript(sqlScriptText)
- },
- condition = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
- parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE"))
- }
-
- test("iterate with wrong label - should fail") {
- val sqlScriptText =
- """
- |BEGIN
- | lbl: BEGIN
- | SELECT 1;
- | ITERATE randomlbl;
- | END;
- |END""".stripMargin
- checkError(
- exception = intercept[SqlScriptingException] {
- runSqlScript(sqlScriptText)
- },
- condition = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
- parameters = Map("labelName" -> "RANDOMLBL", "statementType" ->
"ITERATE"))
- }
-
test("leave outer loop from nested repeat loop") {
val sqlScriptText =
"""
@@ -1281,9 +840,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| UNTIL 1 = 2
| END REPEAT;
|END""".stripMargin
- val expected = Seq(
- Seq(Row(1)) // select 1
- )
+ val expected = Seq(Seq(Row(1)))
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1298,9 +855,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END WHILE;
| END WHILE;
|END""".stripMargin
- val expected = Seq(
- Seq(Row(1)) // select
- )
+ val expected = Seq(Seq(Row(1)))
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1320,14 +875,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| SELECT x;
|END""".stripMargin
val expected = Seq(
- Seq.empty[Row], // declare
- Seq.empty[Row], // set x = 0
- Seq.empty[Row], // set x = 1
Seq(Row(1)), // select 1
- Seq.empty[Row], // set x = 2
Seq(Row(1)), // select 1
- Seq(Row(2)), // select x
- Seq.empty[Row] // drop
+ Seq(Row(2)) // select x
)
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1352,16 +902,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| SELECT x;
|END""".stripMargin
val expected = Seq(
- Seq.empty[Row], // declare
- Seq.empty[Row], // set x = 0
- Seq.empty[Row], // set x = 1
Seq(Row(1)), // select 1
Seq(Row(2)), // select 2
- Seq.empty[Row], // set x = 2
Seq(Row(1)), // select 1
Seq(Row(2)), // select 2
- Seq(Row(2)), // select x
- Seq.empty[Row] // drop
+ Seq(Row(2)) // select x
)
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1384,14 +929,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| SELECT x;
|END""".stripMargin
val expected = Seq(
- Seq.empty[Row], // declare
- Seq.empty[Row], // set x = 0
- Seq.empty[Row], // set x = 1
Seq(Row(1)), // select 1
- Seq.empty[Row], // set x = 2
Seq(Row(1)), // select 1
- Seq(Row(2)), // select x
- Seq.empty[Row] // drop
+ Seq(Row(2)) // select x
)
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1413,16 +953,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| SELECT x;
|END""".stripMargin
val expected = Seq(
- Seq.empty[Row], // declare
- Seq.empty[Row], // set x = 0
- Seq.empty[Row], // set x = 1
Seq(Row(1)), // select x
- Seq.empty[Row], // set x = 2
Seq(Row(2)), // select x
- Seq.empty[Row], // set x = 3
- Seq(Row(3)), // select x
Seq(Row(3)), // select x
- Seq.empty[Row] // drop
+ Seq(Row(3)) // select x
)
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1451,22 +985,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
|""".stripMargin
val expected = Seq(
- Seq.empty[Row], // declare x
- Seq.empty[Row], // declare y
- Seq.empty[Row], // set y to 0
Seq(Row(0, 0)), // select x, y
- Seq.empty[Row], // increase y
Seq(Row(0, 1)), // select x, y
- Seq.empty[Row], // increase y
- Seq.empty[Row], // increase x
- Seq.empty[Row], // set y to 0
Seq(Row(1, 0)), // select x, y
- Seq.empty[Row], // increase y
- Seq(Row(1, 1)), // select x, y
- Seq.empty[Row], // increase y
- Seq.empty[Row], // increase x
- Seq.empty[Row], // drop y
- Seq.empty[Row] // drop x
+ Seq(Row(1, 1)) // select x, y
)
verifySqlScriptResult(commands, expected)
}
@@ -1487,14 +1009,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END LOOP;
| SELECT x;
|END""".stripMargin
- val expected = Seq(
- Seq.empty[Row], // declare
- Seq.empty[Row], // set x = 0
- Seq.empty[Row], // set x = 1
- Seq.empty[Row], // set x = 2
- Seq(Row(2)), // select x
- Seq.empty[Row] // drop
- )
+ val expected = Seq(Seq(Row(2)))
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1509,9 +1024,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| END LOOP;
| END LOOP;
|END""".stripMargin
- val expected = Seq(
- Seq(Row(1)) // select 1
- )
+ // Execution immediately leaves the outer loop after SELECT,
+ // so we expect only a single row in the result set.
+ val expected = Seq(Seq(Row(1)))
verifySqlScriptResult(sqlScriptText, expected)
}
@@ -1535,15 +1050,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
| SELECT x;
|END""".stripMargin
val expected = Seq(
- Seq.empty[Row], // declare
- Seq.empty[Row], // set x = 0
- Seq.empty[Row], // set x = 1
Seq(Row(1)), // select 1
- Seq.empty[Row], // set x = 2
Seq(Row(1)), // select 1
- Seq.empty[Row], // set x = 3
- Seq(Row(3)), // select x
- Seq.empty[Row] // drop
+ Seq(Row(3)) // select x
)
verifySqlScriptResult(sqlScriptText, expected)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index b0b844d2b52c..177ffc24d180 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.{SparkConf, SparkException, SparkNumberFormatException}
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row}
import org.apache.spark.sql.catalyst.QueryPlanningTracker
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf
@@ -39,10 +40,12 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
}
// Helpers
- private def runSqlScript(sqlText: String): Array[DataFrame] = {
- val interpreter = SqlScriptingInterpreter()
+ private def runSqlScript(
+ sqlText: String,
+ args: Map[String, Expression] = Map.empty): Array[DataFrame] = {
+ val interpreter = SqlScriptingInterpreter(spark)
val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
- val executionPlan = interpreter.buildExecutionPlan(compoundBody, spark)
+ val executionPlan = interpreter.buildExecutionPlan(compoundBody, args)
executionPlan.flatMap {
case statement: SingleStatementExec =>
if (statement.isExecuted) {