This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9676b1c48cba [SPARK-48348][SPARK-48376][SQL] Introduce `LEAVE` and
`ITERATE` statements
9676b1c48cba is described below
commit 9676b1c48cba47825ff3dd48e609fa3f0b046c02
Author: David Milicevic <[email protected]>
AuthorDate: Thu Sep 5 20:59:16 2024 +0800
[SPARK-48348][SPARK-48376][SQL] Introduce `LEAVE` and `ITERATE` statements
### What changes were proposed in this pull request?
This PR proposes the introduction of `LEAVE` and `ITERATE` statement types to
the SQL Scripting language:
- The `LEAVE` statement can be used in loops as well as in `BEGIN ... END`
compound blocks.
- The `ITERATE` statement can be used only in loops.
This PR introduces:
- Logical operators for both statement types.
- Execution nodes for both statement types.
- Interpreter changes required to build execution plans that support the new
statement types.
- A new error class (`INVALID_LABEL_USAGE`) raised when these statements are used with an invalid label.
- Minor changes required to support new keywords.
### Why are the changes needed?
To add support for new statement types to the SQL Scripting language.
### Does this PR introduce _any_ user-facing change?
This PR introduces new statement types that will be available to users.
However, script execution logic has not been implemented yet, so the new
statements are not yet accessible to users.
### How was this patch tested?
Tests were added to all test suites related to SQL scripting.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47973 from davidm-db/sql_scripting_leave_iterate.
Authored-by: David Milicevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 18 +++
docs/sql-ref-ansi-compliance.md | 2 +
.../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 2 +
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 14 ++
.../spark/sql/catalyst/parser/AstBuilder.scala | 52 +++++-
.../parser/SqlScriptingLogicalOperators.scala | 18 +++
.../spark/sql/errors/SqlScriptingErrors.scala | 23 +++
.../catalyst/parser/SqlScriptingParserSuite.scala | 177 +++++++++++++++++++++
.../sql/scripting/SqlScriptingExecutionNode.scala | 104 +++++++++++-
.../sql/scripting/SqlScriptingInterpreter.scala | 17 +-
.../sql-tests/results/ansi/keywords.sql.out | 2 +
.../resources/sql-tests/results/keywords.sql.out | 2 +
.../scripting/SqlScriptingExecutionNodeSuite.scala | 103 +++++++++++-
.../scripting/SqlScriptingInterpreterSuite.scala | 143 +++++++++++++++++
.../ThriftServerWithSparkContextSuite.scala | 2 +-
15 files changed, 664 insertions(+), 15 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 96105c967225..b42aae1311f4 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2495,6 +2495,24 @@
],
"sqlState" : "F0000"
},
+ "INVALID_LABEL_USAGE" : {
+ "message" : [
+ "The usage of the label <labelName> is invalid."
+ ],
+ "subClass" : {
+ "DOES_NOT_EXIST" : {
+ "message" : [
+ "Label was used in the <statementType> statement, but the label does
not belong to any surrounding block."
+ ]
+ },
+ "ITERATE_IN_COMPOUND" : {
+ "message" : [
+ "ITERATE statement cannot be used with a label that belongs to a
compound (BEGIN...END) body."
+ ]
+ }
+ },
+ "sqlState" : "42K0L"
+ },
"INVALID_LAMBDA_FUNCTION_CALL" : {
"message" : [
"Invalid lambda function call."
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index f5e1ddfd3c57..0ac19e2ae943 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -556,6 +556,7 @@ Below is a list of all the keywords in Spark SQL.
|INVOKER|non-reserved|non-reserved|non-reserved|
|IS|reserved|non-reserved|reserved|
|ITEMS|non-reserved|non-reserved|non-reserved|
+|ITERATE|non-reserved|non-reserved|non-reserved|
|JOIN|reserved|strict-non-reserved|reserved|
|KEYS|non-reserved|non-reserved|non-reserved|
|LANGUAGE|non-reserved|non-reserved|reserved|
@@ -563,6 +564,7 @@ Below is a list of all the keywords in Spark SQL.
|LATERAL|reserved|strict-non-reserved|reserved|
|LAZY|non-reserved|non-reserved|non-reserved|
|LEADING|reserved|non-reserved|reserved|
+|LEAVE|non-reserved|non-reserved|non-reserved|
|LEFT|reserved|strict-non-reserved|reserved|
|LIKE|non-reserved|non-reserved|reserved|
|ILIKE|non-reserved|non-reserved|non-reserved|
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index acfc0011f5d0..6793cb46852b 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -276,6 +276,7 @@ INTO: 'INTO';
INVOKER: 'INVOKER';
IS: 'IS';
ITEMS: 'ITEMS';
+ITERATE: 'ITERATE';
JOIN: 'JOIN';
KEYS: 'KEYS';
LANGUAGE: 'LANGUAGE';
@@ -283,6 +284,7 @@ LAST: 'LAST';
LATERAL: 'LATERAL';
LAZY: 'LAZY';
LEADING: 'LEADING';
+LEAVE: 'LEAVE';
LEFT: 'LEFT';
LIKE: 'LIKE';
ILIKE: 'ILIKE';
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 5b8805821b04..6a23bd394c8c 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -65,6 +65,8 @@ compoundStatement
| beginEndCompoundBlock
| ifElseStatement
| whileStatement
+ | leaveStatement
+ | iterateStatement
;
setStatementWithOptionalVarKeyword
@@ -83,6 +85,14 @@ ifElseStatement
(ELSE elseBody=compoundBody)? END IF
;
+leaveStatement
+ : LEAVE multipartIdentifier
+ ;
+
+iterateStatement
+ : ITERATE multipartIdentifier
+ ;
+
singleStatement
: (statement|setResetStatement) SEMICOLON* EOF
;
@@ -1578,10 +1588,12 @@ ansiNonReserved
| INTERVAL
| INVOKER
| ITEMS
+ | ITERATE
| KEYS
| LANGUAGE
| LAST
| LAZY
+ | LEAVE
| LIKE
| ILIKE
| LIMIT
@@ -1927,11 +1939,13 @@ nonReserved
| INVOKER
| IS
| ITEMS
+ | ITERATE
| KEYS
| LANGUAGE
| LAST
| LAZY
| LEADING
+ | LEAVE
| LIKE
| LONG
| ILIKE
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index b0922542c562..f4638920af3c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer, Set}
import scala.jdk.CollectionConverters._
import scala.util.{Left, Right}
-import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.antlr.v4.runtime.{ParserRuleContext, RuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
@@ -261,6 +261,56 @@ class AstBuilder extends DataTypeAstBuilder
WhileStatement(condition, body, Some(labelText))
}
+  /**
+   * Checks whether a single rule context is the labeled block or loop targeted by a
+   * LEAVE/ITERATE statement.
+   *
+   * @param ctx       Rule context to inspect (an ancestor of the LEAVE/ITERATE statement).
+   * @param label     Lower-cased label named by the statement.
+   * @param isIterate Whether the statement being resolved is ITERATE (as opposed to LEAVE).
+   * @return true if `ctx` is a `BEGIN ... END` block or a WHILE loop whose begin label,
+   *         lower-cased, equals `label`; false otherwise.
+   * @throws SqlScriptingException INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND when an ITERATE
+   *         statement targets the label of a compound block (ITERATE only applies to loops).
+   */
+  private def leaveOrIterateContextHasLabel(
+      ctx: RuleContext, label: String, isIterate: Boolean): Boolean = {
+    ctx match {
+      case c: BeginEndCompoundBlockContext
+        if Option(c.beginLabel()).isDefined &&
+          c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) =>
+        if (isIterate) {
+          throw SqlScriptingErrors.invalidIterateLabelUsageForCompound(CurrentOrigin.get, label)
+        }
+        true
+      case c: WhileStatementContext
+        if Option(c.beginLabel()).isDefined &&
+          c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) =>
+        true
+      case _ => false
+    }
+  }
+
+  /**
+   * Walks `ctx` and its ancestors, innermost first, and reports whether one of them is the
+   * block or loop carrying `label`. Stops at the first match, or when the parent chain ends
+   * (null parent).
+   *
+   * Written as a tail-recursive helper so the visitors below need no `return` inside the
+   * by-name `withOrigin` block (which would be a non-local return implemented via exception).
+   */
+  @scala.annotation.tailrec
+  private def enclosingContextHasLabel(
+      ctx: RuleContext, label: String, isIterate: Boolean): Boolean = {
+    ctx != null &&
+      (leaveOrIterateContextHasLabel(ctx, label, isIterate) ||
+        enclosingContextHasLabel(ctx.parent, label, isIterate))
+  }
+
+  /**
+   * Builds a [[LeaveStatement]] after verifying that its label belongs to an enclosing
+   * compound block or loop.
+   *
+   * @throws SqlScriptingException INVALID_LABEL_USAGE.DOES_NOT_EXIST if no enclosing block
+   *         or loop carries the label.
+   */
+  override def visitLeaveStatement(ctx: LeaveStatementContext): LeaveStatement =
+    withOrigin(ctx) {
+      val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
+      if (!enclosingContextHasLabel(ctx.parent, labelText, isIterate = false)) {
+        throw SqlScriptingErrors.labelDoesNotExist(
+          CurrentOrigin.get, labelText, "LEAVE")
+      }
+      LeaveStatement(labelText)
+    }
+
+  /**
+   * Builds an [[IterateStatement]] after verifying that its label belongs to an enclosing
+   * loop.
+   *
+   * @throws SqlScriptingException INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND if the nearest
+   *         matching label belongs to a compound block, or INVALID_LABEL_USAGE.DOES_NOT_EXIST
+   *         if no enclosing block or loop carries the label.
+   */
+  override def visitIterateStatement(ctx: IterateStatementContext): IterateStatement =
+    withOrigin(ctx) {
+      val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
+      if (!enclosingContextHasLabel(ctx.parent, labelText, isIterate = true)) {
+        throw SqlScriptingErrors.labelDoesNotExist(
+          CurrentOrigin.get, labelText, "ITERATE")
+      }
+      IterateStatement(labelText)
+    }
+
override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan
= withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext]))
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
index 4a5259f09a8a..dbb29a71323e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala
@@ -89,3 +89,21 @@ case class WhileStatement(
condition: SingleStatement,
body: CompoundBody,
label: Option[String]) extends CompoundPlanStatement
+
+/**
+ * Logical operator for the LEAVE statement.
+ * The statement can be used both in compound (BEGIN ... END) blocks and in any kind of loop.
+ * When used, execution of the corresponding body/loop is skipped and execution continues
+ * with the next statement after the body/loop.
+ * The parser rejects a LEAVE whose label does not belong to any surrounding block
+ * (INVALID_LABEL_USAGE.DOES_NOT_EXIST).
+ * @param label Label of the compound or loop to leave (stored lower-cased by the parser).
+ */
+case class LeaveStatement(label: String) extends CompoundPlanStatement
+
+/**
+ * Logical operator for the ITERATE statement.
+ * The statement can be used only in loops; the parser rejects an ITERATE that targets the
+ * label of a compound block (INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND).
+ * When used, the rest of the current loop iteration is skipped and loop execution continues
+ * with the next iteration.
+ * @param label Label of the loop to iterate (stored lower-cased by the parser).
+ */
+case class IterateStatement(label: String) extends CompoundPlanStatement
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
index 61661b1d32f3..591d2e3e53d4 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/SqlScriptingErrors.scala
@@ -84,4 +84,27 @@ private[sql] object SqlScriptingErrors {
cause = null,
messageParameters = Map("invalidStatement" -> toSQLStmt(stmt)))
}
+
+ /**
+ * Error for a LEAVE/ITERATE statement whose label does not belong to any surrounding block.
+ * @param origin Origin of the offending statement.
+ * @param labelName Label named by the statement.
+ * @param statementType Statement keyword used in the message, e.g. "LEAVE" or "ITERATE".
+ * @return A SqlScriptingException with error class INVALID_LABEL_USAGE.DOES_NOT_EXIST.
+ */
+ // NOTE(review): the label is formatted with toSQLStmt, which upper-cases it (the parser
+ // tests expect "LBL" for label "lbl"); toSQLId may suit an identifier better — confirm.
+ def labelDoesNotExist(
+ origin: Origin,
+ labelName: String,
+ statementType: String): Throwable = {
+ new SqlScriptingException(
+ origin = origin,
+ errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
+ cause = null,
+ messageParameters = Map(
+ "labelName" -> toSQLStmt(labelName),
+ "statementType" -> statementType))
+ }
+
+ /**
+ * Error for an ITERATE statement whose label belongs to a compound (BEGIN ... END) block;
+ * ITERATE may only target loops.
+ * @param origin Origin of the offending ITERATE statement.
+ * @param labelName Label named by the ITERATE statement.
+ * @return A SqlScriptingException with error class INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND.
+ */
+ def invalidIterateLabelUsageForCompound(
+ origin: Origin,
+ labelName: String): Throwable = {
+ new SqlScriptingException(
+ origin = origin,
+ errorClass = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND",
+ cause = null,
+ messageParameters = Map("labelName" -> toSQLStmt(labelName)))
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index 5fc3ade408bd..465c2d408f26 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -666,7 +666,184 @@ class SqlScriptingParserSuite extends SparkFunSuite with
SQLHelper {
head.asInstanceOf[SingleStatement].getText == "SELECT 42")
assert(whileStmt.label.contains("lbl"))
+ }
+
+ test("leave compound block") {
+ val sqlScriptText =
+ """
+ |lbl: BEGIN
+ | SELECT 1;
+ | LEAVE lbl;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 2)
+ assert(tree.collection.head.isInstanceOf[SingleStatement])
+ assert(tree.collection(1).isInstanceOf[LeaveStatement])
+ }
+
+ test("leave while loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: WHILE 1 = 1 DO
+ | SELECT 1;
+ | LEAVE lbl;
+ | END WHILE;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[WhileStatement])
+
+ val whileStmt = tree.collection.head.asInstanceOf[WhileStatement]
+ assert(whileStmt.condition.isInstanceOf[SingleStatement])
+ assert(whileStmt.condition.getText == "1 = 1")
+
+ assert(whileStmt.body.isInstanceOf[CompoundBody])
+ assert(whileStmt.body.collection.length == 2)
+
+ assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(whileStmt.body.collection(1).isInstanceOf[LeaveStatement])
+ assert(whileStmt.body.collection(1).asInstanceOf[LeaveStatement].label ==
"lbl")
+ }
+
+ test ("iterate compound block - should fail") {
+ val sqlScriptText =
+ """
+ |lbl: BEGIN
+ | SELECT 1;
+ | ITERATE lbl;
+ |END""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ parseScript(sqlScriptText)
+ },
+ errorClass = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND",
+ parameters = Map("labelName" -> "LBL"))
+ }
+
+ test("iterate while loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: WHILE 1 = 1 DO
+ | SELECT 1;
+ | ITERATE lbl;
+ | END WHILE;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[WhileStatement])
+
+ val whileStmt = tree.collection.head.asInstanceOf[WhileStatement]
+ assert(whileStmt.condition.isInstanceOf[SingleStatement])
+ assert(whileStmt.condition.getText == "1 = 1")
+
+ assert(whileStmt.body.isInstanceOf[CompoundBody])
+ assert(whileStmt.body.collection.length == 2)
+
+ assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText ==
"SELECT 1")
+
+ assert(whileStmt.body.collection(1).isInstanceOf[IterateStatement])
+ assert(whileStmt.body.collection(1).asInstanceOf[IterateStatement].label
== "lbl")
+ }
+
+ test("leave with wrong label - should fail") {
+ val sqlScriptText =
+ """
+ |lbl: BEGIN
+ | SELECT 1;
+ | LEAVE randomlbl;
+ |END""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ parseScript(sqlScriptText)
+ },
+ errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
+ parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE"))
+ }
+
+ test("iterate with wrong label - should fail") {
+ val sqlScriptText =
+ """
+ |lbl: BEGIN
+ | SELECT 1;
+ | ITERATE randomlbl;
+ |END""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ parseScript(sqlScriptText)
+ },
+ errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
+ parameters = Map("labelName" -> "RANDOMLBL", "statementType" ->
"ITERATE"))
+ }
+
+ test("leave outer loop from nested while loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: WHILE 1 = 1 DO
+ | lbl2: WHILE 2 = 2 DO
+ | SELECT 1;
+ | LEAVE lbl;
+ | END WHILE;
+ | END WHILE;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[WhileStatement])
+
+ val whileStmt = tree.collection.head.asInstanceOf[WhileStatement]
+ assert(whileStmt.condition.isInstanceOf[SingleStatement])
+ assert(whileStmt.condition.getText == "1 = 1")
+
+ assert(whileStmt.body.isInstanceOf[CompoundBody])
+ assert(whileStmt.body.collection.length == 1)
+
+ val nestedWhileStmt =
whileStmt.body.collection.head.asInstanceOf[WhileStatement]
+ assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement])
+ assert(nestedWhileStmt.condition.getText == "2 = 2")
+
+ assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(nestedWhileStmt.body.collection.head.asInstanceOf[SingleStatement].getText
== "SELECT 1")
+
+ assert(nestedWhileStmt.body.collection(1).isInstanceOf[LeaveStatement])
+
assert(nestedWhileStmt.body.collection(1).asInstanceOf[LeaveStatement].label ==
"lbl")
+ }
+
+ test("iterate outer loop from nested while loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: WHILE 1 = 1 DO
+ | lbl2: WHILE 2 = 2 DO
+ | SELECT 1;
+ | ITERATE lbl;
+ | END WHILE;
+ | END WHILE;
+ |END""".stripMargin
+ val tree = parseScript(sqlScriptText)
+ assert(tree.collection.length == 1)
+ assert(tree.collection.head.isInstanceOf[WhileStatement])
+
+ val whileStmt = tree.collection.head.asInstanceOf[WhileStatement]
+ assert(whileStmt.condition.isInstanceOf[SingleStatement])
+ assert(whileStmt.condition.getText == "1 = 1")
+
+ assert(whileStmt.body.isInstanceOf[CompoundBody])
+ assert(whileStmt.body.collection.length == 1)
+
+ val nestedWhileStmt =
whileStmt.body.collection.head.asInstanceOf[WhileStatement]
+ assert(nestedWhileStmt.condition.isInstanceOf[SingleStatement])
+ assert(nestedWhileStmt.condition.getText == "2 = 2")
+
+ assert(nestedWhileStmt.body.collection.head.isInstanceOf[SingleStatement])
+
assert(nestedWhileStmt.body.collection.head.asInstanceOf[SingleStatement].getText
== "SELECT 1")
+ assert(nestedWhileStmt.body.collection(1).isInstanceOf[IterateStatement])
+
assert(nestedWhileStmt.body.collection(1).asInstanceOf[IterateStatement].label
== "lbl")
}
// Helper methods
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 7085366c3b7a..c2e6abf184b5 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -140,13 +140,20 @@ class SingleStatementExec(
* Implements recursive iterator logic over all child execution nodes.
* @param collection
* Collection of child execution nodes.
+ * @param label
+ * Label set by user or None otherwise.
*/
-abstract class CompoundNestedStatementIteratorExec(collection:
Seq[CompoundStatementExec])
+abstract class CompoundNestedStatementIteratorExec(
+ collection: Seq[CompoundStatementExec],
+ label: Option[String] = None)
extends NonLeafStatementExec {
private var localIterator = collection.iterator
private var curr = if (localIterator.hasNext) Some(localIterator.next())
else None
+ /** Used to stop the iteration in cases when LEAVE statement is encountered.
*/
+ private var stopIteration = false
+
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
override def hasNext: Boolean = {
@@ -157,7 +164,7 @@ abstract class
CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
case _ => throw SparkException.internalError(
"Unknown statement type encountered during SQL script
interpretation.")
}
- localIterator.hasNext || childHasNext
+ !stopIteration && (localIterator.hasNext || childHasNext)
}
@scala.annotation.tailrec
@@ -165,12 +172,21 @@ abstract class
CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
curr match {
case None => throw SparkException.internalError(
"No more elements to iterate through in the current SQL compound
statement.")
+ case Some(leaveStatement: LeaveStatementExec) =>
+ handleLeaveStatement(leaveStatement)
+ curr = None
+ leaveStatement
case Some(statement: LeafStatementExec) =>
curr = if (localIterator.hasNext) Some(localIterator.next()) else
None
statement
case Some(body: NonLeafStatementExec) =>
if (body.getTreeIterator.hasNext) {
- body.getTreeIterator.next()
+ body.getTreeIterator.next() match {
+ case leaveStatement: LeaveStatementExec =>
+ handleLeaveStatement(leaveStatement)
+ leaveStatement
+ case other => other
+ }
} else {
curr = if (localIterator.hasNext) Some(localIterator.next())
else None
next()
@@ -187,6 +203,20 @@ abstract class
CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
collection.foreach(_.reset())
localIterator = collection.iterator
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
+ stopIteration = false
+ }
+
+  /**
+   * Actions to take when a LEAVE statement is returned while iterating this compound.
+   *
+   * An already-matched LEAVE is ignored. Otherwise, iteration of this compound stops. The
+   * statement is marked as matched if it targets this compound's label, so that enclosing
+   * nodes stop searching for a match.
+   *
+   * @param leaveStatement The LEAVE statement encountered (directly or from a nested node).
+   */
+  private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = {
+    if (!leaveStatement.hasBeenMatched) {
+      // Stop the iteration.
+      stopIteration = true
+
+      // TODO: Variable cleanup (once we add SQL script execution logic).
+
+      // Check if label has been matched (same Option idiom as WhileStatementExec).
+      leaveStatement.hasBeenMatched = label.contains(leaveStatement.label)
+    }
+  }
}
@@ -194,9 +224,11 @@ abstract class
CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
* Executable node for CompoundBody.
* @param statements
* Executable nodes for nested statements within the CompoundBody.
+ * @param label
+ * Label set by user to CompoundBody or None otherwise.
*/
-class CompoundBodyExec(statements: Seq[CompoundStatementExec])
- extends CompoundNestedStatementIteratorExec(statements)
+class CompoundBodyExec(statements: Seq[CompoundStatementExec], label:
Option[String] = None)
+ extends CompoundNestedStatementIteratorExec(statements, label)
/**
* Executable node for IfElseStatement.
@@ -277,11 +309,13 @@ class IfElseStatementExec(
* Executable node for WhileStatement.
* @param condition Executable node for the condition.
* @param body Executable node for the body.
+ * @param label Label set to WhileStatement by user or None otherwise.
* @param session Spark session that SQL script is executed within.
*/
class WhileStatementExec(
condition: SingleStatementExec,
body: CompoundBodyExec,
+ label: Option[String],
session: SparkSession) extends NonLeafStatementExec {
private object WhileState extends Enumeration {
@@ -308,6 +342,26 @@ class WhileStatementExec(
condition
case WhileState.Body =>
val retStmt = body.getTreeIterator.next()
+
+ // Handle LEAVE or ITERATE statement if it has been encountered.
+ retStmt match {
+ case leaveStatementExec: LeaveStatementExec if
!leaveStatementExec.hasBeenMatched =>
+ if (label.contains(leaveStatementExec.label)) {
+ leaveStatementExec.hasBeenMatched = true
+ }
+ curr = None
+ return retStmt
+ case iterStatementExec: IterateStatementExec if
!iterStatementExec.hasBeenMatched =>
+ if (label.contains(iterStatementExec.label)) {
+ iterStatementExec.hasBeenMatched = true
+ }
+ state = WhileState.Condition
+ curr = Some(condition)
+ condition.reset()
+ return retStmt
+ case _ =>
+ }
+
if (!body.getTreeIterator.hasNext) {
state = WhileState.Condition
curr = Some(condition)
@@ -326,3 +380,43 @@ class WhileStatementExec(
body.reset()
}
}
+
+/**
+ * Executable node for LeaveStatement.
+ * @param label Label of the compound or loop to leave.
+ */
+class LeaveStatementExec(val label: String) extends LeafStatementExec {
+ /**
+ * The label specified in the LEAVE statement might not belong to the immediately
+ * surrounding compound, but to any enclosing compound or loop.
+ * Iteration logic is recursive: when iterating through a compound and another compound
+ * is encountered, next() is called to iterate its body, and the same logic applies to
+ * every compound further down the traversal tree.
+ * Therefore, when a LEAVE statement is encountered (as a leaf of the traversal tree), it is
+ * propagated upwards, and each surrounding compound or loop tries to match it against its
+ * own label.
+ * Once a match is found, this flag is set to true to indicate that the search should stop.
+ */
+ var hasBeenMatched: Boolean = false
+ /** Clears the match flag when this node is reset. */
+ override def reset(): Unit = hasBeenMatched = false
+}
+
+/**
+ * Executable node for ITERATE statement.
+ * @param label Label of the loop to iterate.
+ */
+class IterateStatementExec(val label: String) extends LeafStatementExec {
+ /**
+ * The label specified in the ITERATE statement might not belong to the immediately
+ * surrounding loop, but to any enclosing loop.
+ * Iteration logic is recursive: when iterating through a compound and another compound
+ * is encountered, next() is called to iterate its body, and the same logic applies to
+ * every compound further down the tree.
+ * Therefore, when an ITERATE statement is encountered (as a leaf of the traversal tree), it
+ * is propagated upwards, and each surrounding loop tries to match it against its own label.
+ * Once a match is found, this flag is set to true to indicate that the search should stop.
+ */
+ var hasBeenMatched: Boolean = false
+ /** Clears the match flag when this node is reset. */
+ override def reset(): Unit = hasBeenMatched = false
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 08b4f9728628..8a5a9774d42f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
-import org.apache.spark.sql.catalyst.parser.{CompoundBody,
CompoundPlanStatement, IfElseStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.parser.{CompoundBody,
CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement,
SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable,
DropVariable, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.Origin
@@ -71,9 +71,9 @@ case class SqlScriptingInterpreter() {
private def transformTreeIntoExecutable(
node: CompoundPlanStatement, session: SparkSession):
CompoundStatementExec =
node match {
- case body: CompoundBody =>
+ case CompoundBody(collection, label) =>
// TODO [SPARK-48530]: Current logic doesn't support scoped variables
and shadowing.
- val variables = body.collection.flatMap {
+ val variables = collection.flatMap {
case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan)
case _ => None
}
@@ -82,7 +82,8 @@ case class SqlScriptingInterpreter() {
.map(new SingleStatementExec(_, Origin(), isInternal = true))
.reverse
new CompoundBodyExec(
- body.collection.map(st => transformTreeIntoExecutable(st, session))
++ dropVariables)
+ collection.map(st => transformTreeIntoExecutable(st, session)) ++
dropVariables,
+ label)
case IfElseStatement(conditions, conditionalBodies, elseBody) =>
val conditionsExec = conditions.map(condition =>
new SingleStatementExec(condition.parsedPlan, condition.origin,
isInternal = false))
@@ -92,12 +93,16 @@ case class SqlScriptingInterpreter() {
transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec])
new IfElseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec,
session)
- case WhileStatement(condition, body, _) =>
+ case WhileStatement(condition, body, label) =>
val conditionExec =
new SingleStatementExec(condition.parsedPlan, condition.origin,
isInternal = false)
val bodyExec =
transformTreeIntoExecutable(body,
session).asInstanceOf[CompoundBodyExec]
- new WhileStatementExec(conditionExec, bodyExec, session)
+ new WhileStatementExec(conditionExec, bodyExec, label, session)
+ case leaveStatement: LeaveStatement =>
+ new LeaveStatementExec(leaveStatement.label)
+ case iterateStatement: IterateStatement =>
+ new IterateStatementExec(iterateStatement.label)
case sparkStatement: SingleStatement =>
new SingleStatementExec(
sparkStatement.parsedPlan,
diff --git
a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
index 5735e5eef68e..b2f3fdda74db 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
@@ -163,6 +163,7 @@ INTO true
INVOKER false
IS true
ITEMS false
+ITERATE false
JOIN true
KEYS false
LANGUAGE false
@@ -170,6 +171,7 @@ LAST false
LATERAL true
LAZY false
LEADING true
+LEAVE false
LEFT true
LIKE false
LIMIT false
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
index ca48e851e717..ce9fd580b2ff 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
@@ -163,6 +163,7 @@ INTO false
INVOKER false
IS false
ITEMS false
+ITERATE false
JOIN false
KEYS false
LANGUAGE false
@@ -170,6 +171,7 @@ LAST false
LATERAL false
LAZY false
LEADING false
+LEAVE false
LEFT false
LIKE false
LIMIT false
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index 5c36f9e19e6d..97a21c505fdd 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -54,8 +54,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
case class TestWhile(
condition: TestWhileCondition,
- body: CompoundBodyExec)
- extends WhileStatementExec(condition, body, spark) {
+ body: CompoundBodyExec,
+ label: Option[String] = None)
+ extends WhileStatementExec(condition, body, label, spark) {
private var callCount: Int = 0
@@ -77,6 +78,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite
with SharedSparkSessi
case TestLeafStatement(testVal) => testVal
case TestIfElseCondition(_, description) => description
case TestWhileCondition(_, _, description) => description
+ case leaveStmt: LeaveStatementExec => leaveStmt.label
+ case iterateStmt: IterateStatementExec => iterateStmt.label
case _ => fail("Unexpected statement type")
}
@@ -314,4 +317,100 @@ class SqlScriptingExecutionNodeSuite extends
SparkFunSuite with SharedSparkSessi
"con2", "body1", "con2", "con1"))
}
+ test("leave compound block") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestLeafStatement("one"),
+ new LeaveStatementExec("lbl")
+ ),
+ label = Some("lbl")
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("one", "lbl"))
+ }
+
+ test("leave while loop") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestWhile(
+ condition = TestWhileCondition(condVal = true, reps = 2, description
= "con1"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new LeaveStatementExec("lbl"))
+ ),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("con1", "body1", "lbl"))
+ }
+
+ test("iterate while loop") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestWhile(
+ condition = TestWhileCondition(condVal = true, reps = 2, description
= "con1"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new IterateStatementExec("lbl"),
+ TestLeafStatement("body2"))
+ ),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("con1", "body1", "lbl", "con1", "body1", "lbl",
"con1"))
+ }
+
+ test("leave outer loop from nested while loop") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestWhile(
+ condition = TestWhileCondition(condVal = true, reps = 2, description
= "con1"),
+ body = new CompoundBodyExec(Seq(
+ TestWhile(
+ condition = TestWhileCondition(condVal = true, reps = 2,
description = "con2"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new LeaveStatementExec("lbl"))
+ ),
+ label = Some("lbl2")
+ )
+ )),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("con1", "con2", "body1", "lbl"))
+ }
+
+ test("iterate outer loop from nested while loop") {
+ val iter = new CompoundBodyExec(
+ statements = Seq(
+ TestWhile(
+ condition = TestWhileCondition(condVal = true, reps = 2, description
= "con1"),
+ body = new CompoundBodyExec(Seq(
+ TestWhile(
+ condition = TestWhileCondition(condVal = true, reps = 2,
description = "con2"),
+ body = new CompoundBodyExec(Seq(
+ TestLeafStatement("body1"),
+ new IterateStatementExec("lbl"),
+ TestLeafStatement("body2"))
+ ),
+ label = Some("lbl2")
+ )
+ )),
+ label = Some("lbl")
+ )
+ )
+ ).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq(
+ "con1", "con2", "body1", "lbl",
+ "con1", "con2", "body1", "lbl",
+ "con1"))
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 592516de84c1..5568f85fc476 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest,
Row}
import org.apache.spark.sql.catalyst.QueryPlanningTracker
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parseScript
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.test.SharedSparkSession
@@ -536,4 +537,146 @@ class SqlScriptingInterpreterSuite extends QueryTest with
SharedSparkSession {
verifySqlScriptResult(commands, expected)
}
}
+
+ test("leave compound block") {
+ val sqlScriptText =
+ """
+ |lbl: BEGIN
+ | SELECT 1;
+ | LEAVE lbl;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq(Row(1)) // select
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
+ test("leave while loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: WHILE 1 = 1 DO
+ | SELECT 1;
+ | LEAVE lbl;
+ | END WHILE;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq(Row(1)) // select
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
+ test("iterate compound block - should fail") {
+ val sqlScriptText =
+ """
+ |lbl: BEGIN
+ | SELECT 1;
+ | ITERATE lbl;
+ |END""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ parseScript(sqlScriptText)
+ },
+ errorClass = "INVALID_LABEL_USAGE.ITERATE_IN_COMPOUND",
+ parameters = Map("labelName" -> "LBL"))
+ }
+
+ test("iterate while loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE x INT;
+ | SET x = 0;
+ | lbl: WHILE x < 2 DO
+ | SET x = x + 1;
+ | ITERATE lbl;
+ | SET x = x + 2;
+ | END WHILE;
+ | SELECT x;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq.empty[Row], // declare
+ Seq.empty[Row], // set x = 0
+ Seq.empty[Row], // set x = 1
+ Seq.empty[Row], // set x = 2
+ Seq(Row(2)), // select
+ Seq.empty[Row] // drop
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
+ test("leave with wrong label - should fail") {
+ val sqlScriptText =
+ """
+ |lbl: BEGIN
+ | SELECT 1;
+ | LEAVE randomlbl;
+ |END""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ parseScript(sqlScriptText)
+ },
+ errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
+ parameters = Map("labelName" -> "RANDOMLBL", "statementType" -> "LEAVE"))
+ }
+
+ test("iterate with wrong label - should fail") {
+ val sqlScriptText =
+ """
+ |lbl: BEGIN
+ | SELECT 1;
+ | ITERATE randomlbl;
+ |END""".stripMargin
+ checkError(
+ exception = intercept[SqlScriptingException] {
+ parseScript(sqlScriptText)
+ },
+ errorClass = "INVALID_LABEL_USAGE.DOES_NOT_EXIST",
+ parameters = Map("labelName" -> "RANDOMLBL", "statementType" ->
"ITERATE"))
+ }
+
+ test("leave outer loop from nested while loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | lbl: WHILE 1 = 1 DO
+ | lbl2: WHILE 2 = 2 DO
+ | SELECT 1;
+ | LEAVE lbl;
+ | END WHILE;
+ | END WHILE;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq(Row(1)) // select
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
+
+ test("iterate outer loop from nested while loop") {
+ val sqlScriptText =
+ """
+ |BEGIN
+ | DECLARE x INT;
+ | SET x = 0;
+ | lbl: WHILE x < 2 DO
+ | SET x = x + 1;
+ | lbl2: WHILE 2 = 2 DO
+ | SELECT 1;
+ | ITERATE lbl;
+ | END WHILE;
+ | END WHILE;
+ | SELECT x;
+ |END""".stripMargin
+ val expected = Seq(
+ Seq.empty[Row], // declare
+ Seq.empty[Row], // set x = 0
+ Seq.empty[Row], // set x = 1
+ Seq(Row(1)), // select 1
+ Seq.empty[Row], // set x= 2
+ Seq(Row(1)), // select 1
+ Seq(Row(2)), // select x
+ Seq.empty[Row] // drop
+ )
+ verifySqlScriptResult(sqlScriptText, expected)
+ }
}
diff --git
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index 7005f0e951b2..2e3457dab09b 100644
---
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends
SharedThriftServer {
val sessionHandle = client.openSession(user, "")
val infoValue = client.getInfo(sessionHandle,
GetInfoType.CLI_ODBC_KEYWORDS)
// scalastyle:off line.size.limit
- assert(infoValue.getStringValue ==
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DAT
[...]
+ assert(infoValue.getStringValue ==
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DAT
[...]
// scalastyle:on line.size.limit
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]