This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d7043c0c7c7e [SPARK-52393][SDP] Pipeline SQL Graph Element Registration
d7043c0c7c7e is described below
commit d7043c0c7c7e71c9f3a862ea2bd434e2bc6735ee
Author: Anish Mahto <[email protected]>
AuthorDate: Fri Jun 6 20:21:40 2025 -0700
[SPARK-52393][SDP] Pipeline SQL Graph Element Registration
### What changes were proposed in this pull request?
Add functionality to register graph elements (tables, views, flows) in a
declarative pipeline's DataflowGraph object, from the SQL files sent in
`DefineSqlDataset` requests to the Spark Connect backend.
This involves parsing the SQL text, interpreting the extracted logical
plans, and constructing appropriate graph element objects (Table, View, Flow).
The consequence is that when a pipeline is eventually run, the graph
elements registered from SQL files will materialize and produce the correct
streaming table, materialized view, or temporary view during execution.
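For illustration, a minimal sketch of how a SQL source file is fed through registration (`graphContext`, `src`, `src2`, and `pipeline.sql` are hypothetical placeholders; `SqlGraphRegistrationContext.processSqlFile` and the statement forms are taken from the code and tests in this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.pipelines.graph.{GraphRegistrationContext, SqlGraphRegistrationContext}

// Sketch only: assumes a SparkSession and a GraphRegistrationContext are already available.
def registerPipelineSql(spark: SparkSession, graphContext: GraphRegistrationContext): Unit = {
  val sqlRegistrationContext = new SqlGraphRegistrationContext(graphContext)
  sqlRegistrationContext.processSqlFile(
    sqlText = """
      |CREATE MATERIALIZED VIEW mv AS SELECT 1;
      |CREATE STREAMING TABLE st AS SELECT * FROM STREAM src;
      |CREATE FLOW f AS INSERT INTO st BY NAME SELECT * FROM STREAM src2;
      |""".stripMargin,
    sqlFilePath = "pipeline.sql",
    spark = spark
  )
  // Each statement is split out, parsed into a logical plan, and registered on the
  // pipeline's DataflowGraph as a Table, View, or Flow (plus the flows that back them).
}
```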
### Why are the changes needed?
To support the creation of Spark Declarative Pipeline objects from SQL
files.
### Does this PR introduce _any_ user-facing change?
No. The Spark Declarative Pipelines module is not yet released in any Spark
version. This PR implements logic that will eventually help service the
currently unhandled `DefineSqlDataset` Spark Connect request.
### How was this patch tested?
`org.apache.spark.sql.pipelines.graph.SqlPipelineSuite` contains the core
unit tests for SQL graph element registration.
`org.apache.spark.sql.pipelines.graph.SqlQueryOriginSuite` contains unit
tests to verify the query origin is correctly constructed for graph elements
registered from SQL source files.
`org.apache.spark.sql.catalyst.util.StringUtilsSuite` contains unit tests
to verify SQL string splitting logic.
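As a rough illustration of the splitting behavior those tests exercise (a sketch based on the test cases in the diff below, not an authoritative spec):

```scala
import org.apache.spark.sql.catalyst.util.StringUtils

// Semicolons inside quotes or backticks do not delimit statements, and the
// delimiting semicolon itself is not included in the returned statements.
val statements = StringUtils.splitSemiColonWithIndex(
  line = """SELECT "a;b"; SELECT 2""",
  enableSqlScripting = false
)
// statements == List("SELECT \"a;b\"", " SELECT 2")
```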
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #51080 from AnishMahto/pipelines-sql-element-registration.
Lead-authored-by: Anish Mahto <[email protected]>
Co-authored-by: Yuheng Chang <[email protected]>
Co-authored-by: anishm-db <[email protected]>
Co-authored-by: Sandy Ryza <[email protected]>
Co-authored-by: Aakash Japi <[email protected]>
Co-authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 38 +
.../spark/sql/catalyst/util/StringUtils.scala | 264 +++++++
.../spark/sql/catalyst/util/StringUtilsSuite.scala | 131 ++++
.../sql/pipelines/graph/GraphValidations.scala | 2 +-
.../spark/sql/pipelines/graph/QueryOrigin.scala | 12 -
.../sql/pipelines/graph/QueryOriginType.scala | 23 +
.../graph/SqlGraphRegistrationContext.scala | 674 ++++++++++++++++
.../logging/FlowProgressEventLogger.scala | 22 +-
.../sql/pipelines/logging/PipelineEvent.scala | 19 +-
.../sql/pipelines/graph/SqlPipelineSuite.scala | 855 +++++++++++++++++++++
.../sql/pipelines/graph/SqlQueryOriginSuite.scala | 173 +++++
.../logging/ConstructPipelineEventSuite.scala | 13 +-
.../spark/sql/pipelines/utils/PipelineTest.scala | 44 +-
13 files changed, 2221 insertions(+), 49 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index eacea816db13..54021063c78e 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4129,6 +4129,12 @@
],
"sqlState" : "42P20"
},
+ "MULTIPART_FLOW_NAME_NOT_SUPPORTED" : {
+ "message" : [
+ "Flow with multipart name '<flowName>' is not supported."
+ ],
+ "sqlState" : "0A000"
+ },
"MULTIPLE_PRIMARY_KEYS" : {
"message" : [
"Multiple primary keys are defined: <columns>. Please ensure that only
one primary key is defined for the table."
@@ -4614,6 +4620,38 @@
],
"sqlState" : "42K03"
},
+ "PIPELINE_DATASET_WITHOUT_FLOW" : {
+ "message" : [
+ "Pipeline dataset <identifier> does not have any defined flows. Please
attach a query with the dataset's definition, or explicitly define at least one
flow that writes to the dataset."
+ ],
+ "sqlState" : "0A000"
+ },
+ "PIPELINE_DUPLICATE_IDENTIFIERS" : {
+ "message" : [
+ "A duplicate identifier was found for elements registered in the
pipeline's dataflow graph."
+ ],
+ "subClass" : {
+ "DATASET" : {
+ "message" : [
+ "Attempted to register a <datasetType1> with identifier
<datasetName>, but a <datasetType2> has already been registered with that
identifier. Please ensure all datasets created within this pipeline have unique
identifiers."
+ ]
+ },
+ "FLOW" : {
+ "message" : [
+ "Flow <flowName> was found in multiple datasets: <datasetNames>"
+ ]
+ }
+ },
+ "sqlState" : "42710"
+ },
+ "PIPELINE_SQL_GRAPH_ELEMENT_REGISTRATION_ERROR" : {
+ "message" : [
+ "<message>",
+ "<offendingQuery>",
+ "<codeLocation>"
+ ],
+ "sqlState" : "42000"
+ },
"PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION" : {
"message" : [
"Non-grouping expression <expr> is provided as an argument to the |>
AGGREGATE pipe operator but does not contain any aggregate function; please
update it to include an aggregate function and then retry the query again."
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index e2a5319cbe1a..b4d737dcf791 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql.catalyst.util
+import java.util.Locale
import java.util.regex.{Pattern, PatternSyntaxException}
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.commons.text.similarity.LevenshteinDistance
import org.apache.spark.internal.{Logging, MDC}
@@ -150,4 +153,265 @@ object StringUtils extends Logging {
}
}
}
+
+ /**
+ * Removes comments from a SQL command. Visible for testing only.
+ * @param command The SQL command to remove comments from.
+ * @param replaceWithWhitespace If true, replaces comments with
whitespace instead of
+ * stripping them in order to ensure query
length and character
+ * positions are preserved.
+ */
+ protected[util] def stripComment(
+ command: String, replaceWithWhitespace: Boolean = false): String = {
+ // Important characters
+ val SINGLE_QUOTE = '\''
+ val DOUBLE_QUOTE = '"'
+ val BACKTICK = '`'
+ val BACKSLASH = '\\'
+ val HYPHEN = '-'
+ val NEWLINE = '\n'
+ val STAR = '*'
+ val SLASH = '/'
+
+ // Possible states
+ object Quote extends Enumeration {
+ type State = Value
+ val InSingleQuote, InDoubleQuote, InComment, InBacktick, NoQuote,
InBracketedComment = Value
+ }
+ import Quote._
+
+ var curState = NoQuote
+ var curIdx = 0
+ val singleCommand = new StringBuilder()
+
+ val skipNextCharacter = () => {
+ curIdx += 1
+ // Optionally append whitespace when skipping next character
+ if (replaceWithWhitespace) {
+ singleCommand.append(" ")
+ }
+ }
+
+ while (curIdx < command.length) {
+ var curChar = command.charAt(curIdx)
+ var appendCharacter = true
+
+ (curState, curChar) match {
+ case (InBracketedComment, STAR) =>
+ val nextIdx = curIdx + 1
+ if (nextIdx < command.length && command.charAt(nextIdx) == SLASH) {
+ curState = NoQuote
+ skipNextCharacter()
+ }
+ appendCharacter = false
+ case (InComment, NEWLINE) =>
+ curState = NoQuote
+ case (InComment, _) =>
+ appendCharacter = false
+ case (InBracketedComment, _) =>
+ appendCharacter = false
+ case (NoQuote, HYPHEN) =>
+ val nextIdx = curIdx + 1
+ if (nextIdx < command.length && command.charAt(nextIdx) == HYPHEN) {
+ appendCharacter = false
+ skipNextCharacter()
+ curState = InComment
+ }
+ case (NoQuote, DOUBLE_QUOTE) => curState = InDoubleQuote
+ case (NoQuote, SINGLE_QUOTE) => curState = InSingleQuote
+ case (NoQuote, BACKTICK) => curState = InBacktick
+ case (NoQuote, SLASH) =>
+ val nextIdx = curIdx + 1
+ if (nextIdx < command.length && command.charAt(nextIdx) == STAR) {
+ appendCharacter = false
+ skipNextCharacter()
+ curState = InBracketedComment
+ }
+ case (InSingleQuote, SINGLE_QUOTE) => curState = NoQuote
+ case (InDoubleQuote, DOUBLE_QUOTE) => curState = NoQuote
+ case (InBacktick, BACKTICK) => curState = NoQuote
+ case (InDoubleQuote | InSingleQuote, BACKSLASH) =>
+ // This is to make sure we are handling \" or \' within "" or ''
correctly.
+ // For example, select "\"--hello--\""
+ val nextIdx = curIdx + 1
+ if (nextIdx < command.length) {
+ singleCommand.append(curChar)
+ curIdx = nextIdx
+ curChar = command.charAt(curIdx)
+ }
+ case (_, _) => ()
+ }
+
+ if (appendCharacter) {
+ singleCommand.append(curChar)
+ } else if (replaceWithWhitespace) {
+ singleCommand.append(" ")
+ }
+ curIdx += 1
+ }
+
+ singleCommand.toString()
+ }
+
+ /**
+ * Check if query is SQL Script.
+ *
+ * @param query The query string.
+ */
+ def isSqlScript(query: String): Boolean = {
+ val cleanText = stripComment(query).trim.toUpperCase(Locale.ROOT)
+ // SQL Stored Procedure body, specified during procedure creation, is also
a SQL Script.
+ (cleanText.startsWith("BEGIN") && (cleanText.endsWith("END") ||
+ cleanText.endsWith("END;"))) || isCreateSqlStoredProcedureText(cleanText)
+ }
+
+ /**
+ * Check if text is create SQL Stored Procedure command.
+ *
+ * @param cleanText The query text, already stripped of comments and
capitalized
+ */
+ private def isCreateSqlStoredProcedureText(cleanText: String): Boolean = {
+ import scala.util.matching.Regex
+
+ val pattern: Regex =
+
"""(?s)^CREATE\s+(OR\s+REPLACE\s+)?PROCEDURE\s+\w+\s*\(.*?\).*BEGIN.*END\s*;?\s*$""".r
+
+ pattern.matches(cleanText)
+ }
+
+ private def containsNonWhiteSpaceCharacters(inputString: String): Boolean = {
+ val pattern = "\\S".r
+ pattern.findFirstIn(inputString).isDefined
+ }
+
+ // Implementation is grabbed from SparkSQLCLIDriver.splitSemiColon, the only
difference is this
+ // implementation handles backtick and treats it as single/double quote.
+ // Below comments are from the source:
+ // Adapted splitSemiColon from Hive 2.3's CliDriver.splitSemiColon.
+ // Note: [SPARK-31595] if there is a `'` in a double quoted string, or a `"`
in a single quoted
+ // string, the origin implementation from Hive will not drop the trailing
semicolon as expected,
+ // hence we refined this function a little bit.
+ // Note: [SPARK-33100] Ignore a semicolon inside a bracketed comment in
spark-sql.
+ def splitSemiColonWithIndex(line: String, enableSqlScripting: Boolean):
List[String] = {
+ var insideSingleQuote = false
+ var insideDoubleQuote = false
+ var insideBacktick = false
+ var insideSimpleComment = false
+ var bracketedCommentLevel = 0
+ var escape = false
+ var beginIndex = 0
+ var leavingBracketedComment = false
+ var hasPrecedingNonCommentString = false
+ var isStatement = false
+ val ret = new ArrayBuffer[String]()
+
+ lazy val insideSqlScript: Boolean = isSqlScript(line)
+
+ def insideBracketedComment: Boolean = bracketedCommentLevel > 0
+ def insideComment: Boolean = insideSimpleComment || insideBracketedComment
+ def statementInProgress(index: Int): Boolean =
+ isStatement || (!insideComment &&
+ index > beginIndex && !s"${line.charAt(index)}".trim.isEmpty)
+
+ for (index <- 0 until line.length) {
+ // Checks if we need to decrement a bracketed comment level; the last
character '/' of
+ // bracketed comments is still inside the comment, so
`insideBracketedComment` must keep
+ // true in the previous loop and we decrement the level here if needed.
+ if (leavingBracketedComment) {
+ bracketedCommentLevel -= 1
+ leavingBracketedComment = false
+ }
+
+ if (line.charAt(index) == '\'' && !insideComment) {
+ // take a look to see if it is escaped
+ // See the comment above about SPARK-31595
+ if (!escape && !insideDoubleQuote && !insideBacktick) {
+ // flip the boolean variable
+ insideSingleQuote = !insideSingleQuote
+ }
+ } else if (line.charAt(index) == '\"' && !insideComment) {
+ // take a look to see if it is escaped
+ // See the comment above about SPARK-31595
+ if (!escape && !insideSingleQuote && !insideBacktick) {
+ // flip the boolean variable
+ insideDoubleQuote = !insideDoubleQuote
+ }
+ } else if (line.charAt(index) == '`' && !insideComment) {
+ // take a look to see if it is escaped
+ // See the comment above about SPARK-31595
+ if (!escape && !insideSingleQuote && !insideDoubleQuote) {
+ // flip the boolean variable
+ insideBacktick = !insideBacktick
+ }
+ } else if (line.charAt(index) == '-') {
+ val hasNext = index + 1 < line.length
+ if (insideDoubleQuote || insideSingleQuote || insideBacktick ||
insideComment) {
+ // Ignores '-' in any case of quotes or comment.
+ // Avoids to start a comment(--) within a quoted segment or already
in a comment.
+ // Sample query: select "quoted value --"
+ // ^^ avoids starting a comment
if inside quotes.
+ } else if (hasNext && line.charAt(index + 1) == '-') {
+ // ignore quotes and ; in simple comment
+ insideSimpleComment = true
+ }
+ } else if (line.charAt(index) == ';') {
+ if (insideSingleQuote || insideDoubleQuote || insideBacktick ||
insideComment) {
+ // do not split
+ } else if (enableSqlScripting && insideSqlScript) {
+ // do not split
+ } else {
+ if (isStatement) {
+ // split, do not include ; itself
+ ret.append(line.substring(beginIndex, index))
+ }
+ beginIndex = index + 1
+ isStatement = false
+ }
+ } else if (line.charAt(index) == '\n') {
+ // with a new line the inline simple comment should end.
+ if (!escape) {
+ insideSimpleComment = false
+ }
+ } else if (line.charAt(index) == '/' && !insideSimpleComment) {
+ val hasNext = index + 1 < line.length
+ if (insideSingleQuote || insideDoubleQuote || insideBacktick) {
+ // Ignores '/' in any case of quotes
+ } else if (insideBracketedComment && line.charAt(index - 1) == '*') {
+ // Decrements `bracketedCommentLevel` at the beginning of the next
loop
+ leavingBracketedComment = true
+ } else if (hasNext && line.charAt(index + 1) == '*') {
+ bracketedCommentLevel += 1
+ // Check if there's non-comment characters(non space, non newline
characters) before
+ // multiline comments.
+ hasPrecedingNonCommentString = beginIndex != index &&
containsNonWhiteSpaceCharacters(
+ line.substring(beginIndex, index)
+ )
+ }
+ }
+ // set the escape
+ if (escape) {
+ escape = false
+ } else if (line.charAt(index) == '\\') {
+ escape = true
+ }
+
+ isStatement = statementInProgress(index)
+ }
+ // Check the last char is end of nested bracketed comment.
+ val endOfBracketedComment = leavingBracketedComment &&
bracketedCommentLevel == 1 &&
+ !hasPrecedingNonCommentString
+ // Spark SQL support simple comment and nested bracketed comment in query
body.
+ // But if Spark SQL receives a comment alone, it will throw parser
exception.
+ // In Spark SQL CLI, if there is a completed comment in the end of whole
query,
+ // since Spark SQL CLI uses `;` to split the query, CLI will pass the
comment
+ // to the backend engine and throw exception. CLI should ignore this
comment,
+ // If there is an uncompleted statement or an uncompleted bracketed
comment in the end,
+ // CLI should also pass this part to the backend engine, which may throw
an exception
+ // with clear error message (for incomplete statement, if there's non
comment characters,
+ // we would still append the string).
+ if (!endOfBracketedComment && (isStatement || insideBracketedComment)) {
+ ret.append(line.substring(beginIndex))
+ }
+ ret.toList
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
index fb4053964a84..477ce54f59c4 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
@@ -151,4 +151,135 @@ class StringUtilsSuite extends SparkFunSuite with
SQLHelper {
assert(truncatedString(Seq("a", "b", "c"), ", ", Int.MaxValue) === "a, b,
c")
assert(truncatedString(Seq("a", "b", "c"), ", ", Int.MinValue) === "... 3
more fields")
}
+
+ test("SQL comments are stripped correctly") {
+ // single line comment tests
+ assert(stripComment("-- comment") == "")
+ assert(stripComment("--comment") == "")
+ assert(stripComment("-- SELECT * FROM table") == "")
+ assert(stripComment(
+ """-- comment
+ |SELECT * FROM table""".stripMargin) == "\nSELECT * FROM table")
+ assert(stripComment("SELECT * FROM table -- comment") == "SELECT * FROM
table ")
+ assert(stripComment("SELECT '-- not a comment'") == "SELECT '-- not a
comment'")
+ assert(stripComment("SELECT \"-- not a comment\"") == "SELECT \"-- not a
comment\"")
+ assert(stripComment("SELECT 1 -- -- nested comment") == "SELECT 1 ")
+ assert(stripComment("SELECT ' \\' --not a comment'") == "SELECT ' \\'
--not a comment'")
+
+
+ // multiline comment tests
+ assert(stripComment("SELECT /* inline comment */1-- comment") == "SELECT
1")
+ assert(stripComment("SELECT /* inline comment */1") == "SELECT 1")
+ assert(stripComment(
+ """/* my
+ |* multiline
+ | comment */ SELECT * FROM table""".stripMargin) == " SELECT * FROM
table")
+ assert(stripComment("SELECT '/* not a comment */'") == "SELECT '/* not a
comment */'")
+ assert(StringUtils.stripComment(
+ "SELECT \"/* not a comment */\"") == "SELECT \"/* not a comment */\"")
+ assert(stripComment("SELECT 1/* /* nested comment */") == "SELECT 1")
+ assert(stripComment("SELECT ' \\'/*not a comment*/'") == "SELECT '
\\'/*not a comment*/'")
+ }
+
+ test("SQL script detector") {
+ assert(isSqlScript(" BEGIN END"))
+ assert(isSqlScript("BEGIN END;"))
+ assert(isSqlScript("BEGIN END"))
+ assert(isSqlScript(
+ """
+ |BEGIN
+ |
+ |END
+ |""".stripMargin))
+ assert(isSqlScript(
+ """
+ |BEGIN
+ |
+ |END;
+ |""".stripMargin))
+ assert(isSqlScript("BEGIN BEGIN END END"))
+ assert(isSqlScript("BEGIN end"))
+ assert(isSqlScript("begin END"))
+ assert(isSqlScript(
+ """/* header comment
+ |*/
+ |BEGIN
+ |END;
+ |""".stripMargin))
+ assert(isSqlScript(
+ """-- header comment
+ |BEGIN
+ |END;
+ |""".stripMargin))
+ assert(!isSqlScript("-- BEGIN END"))
+ assert(!isSqlScript("/*BEGIN END*/"))
+ assert(isSqlScript("/*BEGIN END*/ BEGIN END"))
+
+ assert(!isSqlScript("CREATE 'PROCEDURE BEGIN' END"))
+ assert(!isSqlScript("CREATE /*PROCEDURE*/ BEGIN END"))
+ assert(!isSqlScript("CREATE PROCEDURE END"))
+ assert(isSqlScript("create ProCeDure p() BEgin END"))
+ assert(isSqlScript("CREATE OR REPLACE PROCEDURE p() BEGIN END"))
+ assert(!isSqlScript("CREATE PROCEDURE BEGIN END")) // procedure must be
named
+ }
+
+ test("SQL string splitter") {
+ // semicolon shouldn't delimit if in quotes
+ assert(
+ splitSemiColonWithIndex(
+ """
+ |SELECT "string;with;semicolons";
+ |USE DATABASE db""".stripMargin,
+ enableSqlScripting = false) == Seq(
+ "\nSELECT \"string;with;semicolons\"",
+ "\nUSE DATABASE db"
+ )
+ )
+
+ // semicolon shouldn't delimit if in backticks
+ assert(
+ splitSemiColonWithIndex(
+ """
+ |SELECT `escaped;sequence;with;semicolons`;
+ |USE DATABASE db""".stripMargin,
+ enableSqlScripting = false) == Seq(
+ "\nSELECT `escaped;sequence;with;semicolons`",
+ "\nUSE DATABASE db"
+ )
+ )
+
+ // white space around command is included in split string
+ assert(
+ splitSemiColonWithIndex(
+ s"""
+ |-- comment 1
+ |-- comment 2
+ |
+ |SELECT 1;\t
+ |-- comment 3
+ |SELECT 2
+ |""".stripMargin,
+ enableSqlScripting = false
+ ) == Seq(
+ "\n-- comment 1\n-- comment 2\n\nSELECT 1",
+ "\t\n-- comment 3\nSELECT 2\n"
+ )
+ )
+
+ // SQL procedures are respected and not split, if configured
+ assert(
+ splitSemiColonWithIndex(
+ """CREATE PROCEDURE p() BEGIN
+ | SELECT 1;
+ | SELECT 2;
+ |END""".stripMargin,
+ enableSqlScripting = true
+ ) == Seq(
+ """CREATE PROCEDURE p() BEGIN
+ | SELECT 1;
+ | SELECT 2;
+ |END""".stripMargin
+ )
+ )
+ }
}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
index b7e0cf86e4dc..648a5154d42e 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala
@@ -166,7 +166,7 @@ trait GraphValidations extends Logging {
/**
* Validates that all flows are resolved. If there are unresolved flows,
- * detects a possible cyclic dependency and throw the appropriate execption.
+ * detects a possible cyclic dependency and throw the appropriate exception.
*/
protected def validateSuccessfulFlowAnalysis(): Unit = {
// all failed flows with their errors
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala
index e260d9693b6d..211fbdd4494c 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala
@@ -22,7 +22,6 @@ import scala.util.control.{NonFatal, NoStackTrace}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.pipelines.Language
-import org.apache.spark.sql.pipelines.logging.SourceCodeLocation
/**
* Records information used to track the provenance of a given query to user
code.
@@ -79,17 +78,6 @@ case class QueryOrigin(
)
)
}
-
- /** Generates a SourceCodeLocation using the details present in the query
origin. */
- def toSourceCodeLocation: SourceCodeLocation = SourceCodeLocation(
- path = filePath,
- // QueryOrigin tracks line numbers using a 1-indexed numbering scheme
whereas SourceCodeLocation
- // tracks them using a 0-indexed numbering scheme.
- lineNumber = line.map(_ - 1),
- columnNumber = startPosition,
- endingLineNumber = None,
- endingColumnNumber = None
- )
}
object QueryOrigin extends Logging {
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOriginType.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOriginType.scala
new file mode 100644
index 000000000000..c24575d58173
--- /dev/null
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOriginType.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines
+
+object QueryOriginType extends Enumeration {
+ type QueryOriginType = Value
+ val Flow, Table, View = Value
+}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
new file mode 100644
index 000000000000..30fe7c8dd524
--- /dev/null
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/SqlGraphRegistrationContext.scala
@@ -0,0 +1,674 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pipelines.graph
+
+import scala.collection.mutable
+
+import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.{CreateFlowCommand,
CreateMaterializedViewAsSelect, CreateStreamingTable,
CreateStreamingTableAsSelect, CreateView, InsertIntoStatement, LogicalPlan}
+import org.apache.spark.sql.catalyst.util.StringUtils
+import org.apache.spark.sql.execution.command.{CreateViewCommand,
SetCatalogCommand, SetCommand, SetNamespaceCommand}
+import org.apache.spark.sql.pipelines.{Language, QueryOriginType}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Data class for all state that is accumulated while processing a particular
+ * [[SqlGraphRegistrationContext]].
+ *
+ * @param initialCatalogOpt The initial catalog to assume.
+ * @param initialDatabaseOpt The initial database to assume.
+ * @param initialSqlConf The initial sql confs to assume.
+ */
+class SqlGraphRegistrationContextState(
+ initialCatalogOpt: Option[String],
+ initialDatabaseOpt: Option[String],
+ initialSqlConf: Map[String, String]) {
+ private val sqlConf = mutable.HashMap[String, String](initialSqlConf.toSeq:
_*)
+ private var currentCatalogOpt: Option[String] = initialCatalogOpt
+ private var currentDatabaseOpt: Option[String] = initialDatabaseOpt
+
+ def getSqlConf: Map[String, String] = sqlConf.toMap
+ def getCurrentCatalogOpt: Option[String] = currentCatalogOpt
+ def getCurrentDatabaseOpt: Option[String] = currentDatabaseOpt
+
+ def setSqlConf(k: String, v: String): Unit = sqlConf.put(k, v)
+ def setCurrentCatalog(catalogName: String): Unit = {
+ currentCatalogOpt = Option(catalogName)
+ }
+ def setCurrentDatabase(databaseName: String): Unit = {
+ currentDatabaseOpt = Option(databaseName)
+ }
+ def clearCurrentDatabase(): Unit = {
+ currentDatabaseOpt = None
+ }
+}
+
+case class SqlGraphElementRegistrationException(
+ msg: String,
+ queryOrigin: QueryOrigin) extends AnalysisException(
+ errorClass = "PIPELINE_SQL_GRAPH_ELEMENT_REGISTRATION_ERROR",
+ messageParameters = Map(
+ "message" -> msg,
+ "offendingQuery" ->
SqlGraphElementRegistrationException.offendingQueryString(queryOrigin),
+ "codeLocation" ->
SqlGraphElementRegistrationException.codeLocationStr(queryOrigin)
+ )
+)
+
+object SqlGraphElementRegistrationException {
+ private def codeLocationStr(queryOrigin: QueryOrigin): String =
queryOrigin.filePath match {
+ case Some(fileName) =>
+ queryOrigin.line match {
+ case Some(lineNumber) =>
+ s"Query defined at $fileName:$lineNumber"
+ case None =>
+ s"Query defined in file $fileName"
+ }
+ case None => ""
+ }
+
+ private def offendingQueryString(queryOrigin: QueryOrigin): String =
queryOrigin.sqlText match {
+ case Some(sqlText) =>
+ s"""
+ |Offending query:
+ |${sqlText}
+ |""".stripMargin
+ case None => ""
+ }
+}
+
+/**
+ * SQL statement processor context. At any instant, an instance of this class
holds the "active"
+ * catalog/schema in use within this SQL statement processing context, and
tables/views/flows that
+ * have been registered from SQL statements within this context.
+ */
+class SqlGraphRegistrationContext(
+ graphRegistrationContext: GraphRegistrationContext) {
+ import SqlGraphRegistrationContext._
+
+ private val defaultDatabase = graphRegistrationContext.defaultDatabase
+ private val defaultCatalog = graphRegistrationContext.defaultCatalog
+
+ private val context = new SqlGraphRegistrationContextState(
+ initialCatalogOpt = Option(defaultCatalog),
+ initialDatabaseOpt = Option(defaultDatabase),
+ initialSqlConf = graphRegistrationContext.defaultSqlConf
+ )
+
+ def processSqlFile(sqlText: String, sqlFilePath: String, spark:
SparkSession): Unit = {
+ // Create a registration context for this SQL registration request
+ val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(
+ graphRegistrationContext
+ )
+
+ splitSqlFileIntoQueries(
+ spark = spark,
+ sqlFileText = sqlText,
+ sqlFilePath = sqlFilePath
+ ).foreach { case SqlQueryPlanWithOrigin(logicalPlan, queryOrigin) =>
+ sqlGraphElementRegistrationContext.processSqlQuery(logicalPlan,
queryOrigin)
+ }
+ }
+
+ private def processSqlQuery(queryPlan: LogicalPlan, queryOrigin:
QueryOrigin): Unit = {
+ queryPlan match {
+ case setCommand: SetCommand =>
+ // SET [ key | 'key' ] [ value | 'value' ]
+ // Sets (or overrides if already set) the value for a spark conf key.
Once set, this conf
+ // is applied for all flow functions registered afterward, until
unset/overwritten.
+ SetCommandHandler.handle(setCommand)
+ case setNamespaceCommand: SetNamespaceCommand =>
+ // USE { NAMESPACE | DATABASE | SCHEMA } [ schema_name | 'schema_name'
]
+ // Sets the current schema. After the current schema is set,
unqualified references to
+ // objects such as tables are resolved from said schema, until
overwritten, within this
+ // SQL processor scope.
+ SetNamespaceCommandHandler.handle(setNamespaceCommand)
+ case setCatalogCommand: SetCatalogCommand =>
+ // USE { CATALOG } [ catalog_name | 'catalog_name' ]
+ // Sets the current catalog. After the current catalog is set,
unqualified references to
+ // objects such as tables are resolved from said catalog, until
overwritten, within this
+ // SQL processor scope. Note that the schema is cleared when the
catalog is set, and must
+ // be explicitly set again in order to implicitly qualify identifiers.
+ SetCatalogCommandHandler.handle(setCatalogCommand)
+ case createPersistedViewCommand: CreateView =>
+ // CREATE VIEW [ persisted_view_name ] [ options ] AS [ query ]
+ CreatePersistedViewCommandHandler.handle(createPersistedViewCommand,
queryOrigin)
+ case createTemporaryViewCommand: CreateViewCommand =>
+ // CREATE TEMPORARY VIEW [ temporary_view_name ] [ options ] AS [
query ]
+ CreateTemporaryViewHandler.handle(createTemporaryViewCommand,
queryOrigin)
+ case createMaterializedViewAsSelectCommand:
CreateMaterializedViewAsSelect =>
+ // CREATE MATERIALIZED VIEW [ materialized_view_name ] [ options ] AS
[ query ]
+ CreateMaterializedViewAsSelectHandler.handle(
+ createMaterializedViewAsSelectCommand,
+ queryOrigin
+ )
+ case createStreamingTableAsSelectCommand: CreateStreamingTableAsSelect =>
+ // CREATE STREAMING TABLE [ streaming_table_name ] [ options ] AS [
query ]
+
CreateStreamingTableAsSelectHandler.handle(createStreamingTableAsSelectCommand,
queryOrigin)
+ case createStreamingTableCommand: CreateStreamingTable =>
+ // CREATE STREAMING TABLE [ streaming_table_name ] [ options ]
+ CreateStreamingTableHandler.handle(createStreamingTableCommand,
queryOrigin)
+ case createFlowCommand: CreateFlowCommand =>
+ // CREATE FLOW [ flow_name ] AS INSERT INTO [ destination_name ] BY
NAME
+ CreateFlowHandler.handle(createFlowCommand, queryOrigin)
+ case unsupportedLogicalPlan: LogicalPlan =>
+ throw SqlGraphElementRegistrationException(
+ msg = s"Unsupported plan ${unsupportedLogicalPlan.nodeName} parsed
from SQL query",
+ queryOrigin = queryOrigin
+ )
+ }
+ }
+
+ private object CreateStreamingTableHandler {
+ def handle(cst: CreateStreamingTable, queryOrigin: QueryOrigin): Unit = {
+ val stIdentifier = GraphIdentifierManager
+ .parseAndQualifyTableIdentifier(
+ rawTableIdentifier = IdentifierHelper.toTableIdentifier(cst.name),
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ )
+ .identifier
+
+ // Register streaming table as a table.
+ graphRegistrationContext.registerTable(
+ Table(
+ identifier = stIdentifier,
+ comment = cst.tableSpec.comment,
+ specifiedSchema =
+
Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
+ partitionCols =
Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
+ properties = cst.tableSpec.properties,
+ baseOrigin = queryOrigin.copy(
+ objectName = Option(stIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Table.toString)
+ ),
+ format = cst.tableSpec.provider,
+ normalizedPath = None,
+ isStreamingTableOpt = None
+ )
+ )
+ }
+ }
+
+ private object CreateStreamingTableAsSelectHandler {
+ def handle(cst: CreateStreamingTableAsSelect, queryOrigin: QueryOrigin):
Unit = {
+ val stIdentifier = GraphIdentifierManager
+ .parseAndQualifyTableIdentifier(
+ rawTableIdentifier = IdentifierHelper.toTableIdentifier(cst.name),
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ )
+ .identifier
+
+ // Register streaming table as a table.
+ graphRegistrationContext.registerTable(
+ Table(
+ identifier = stIdentifier,
+ comment = cst.tableSpec.comment,
+ specifiedSchema =
+
Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
+ partitionCols =
Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
+ properties = cst.tableSpec.properties,
+ baseOrigin = queryOrigin.copy(
+ objectName = Option(stIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Table.toString)
+ ),
+ format = cst.tableSpec.provider,
+ normalizedPath = None,
+ isStreamingTableOpt = None
+ )
+ )
+
+ // Register flow that backs this streaming table.
+ graphRegistrationContext.registerFlow(
+ UnresolvedFlow(
+ identifier = stIdentifier,
+ destinationIdentifier = stIdentifier,
+ func = FlowAnalysis.createFlowFunctionFromLogicalPlan(cst.query),
+ sqlConf = context.getSqlConf,
+ once = false,
+ queryContext = QueryContext(
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ ),
+ comment = cst.tableSpec.comment,
+ origin = queryOrigin.copy(
+ objectName = Option(stIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Flow.toString)
+ )
+ )
+ )
+ }
+ }
+
+ private object CreateMaterializedViewAsSelectHandler {
+ def handle(cmv: CreateMaterializedViewAsSelect, queryOrigin: QueryOrigin):
Unit = {
+ val mvIdentifier = GraphIdentifierManager
+ .parseAndQualifyTableIdentifier(
+ rawTableIdentifier = IdentifierHelper.toTableIdentifier(cmv.name),
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ )
+ .identifier
+
+ // Register materialized view as a table.
+ graphRegistrationContext.registerTable(
+ Table(
+ identifier = mvIdentifier,
+ comment = cmv.tableSpec.comment,
+ specifiedSchema =
+
Option.when(cmv.columns.nonEmpty)(StructType(cmv.columns.map(_.toV1Column))),
+ partitionCols =
Option(PartitionHelper.applyPartitioning(cmv.partitioning, queryOrigin)),
+ properties = cmv.tableSpec.properties,
+ baseOrigin = queryOrigin.copy(
+ objectName = Option(mvIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Table.toString)
+ ),
+ format = cmv.tableSpec.provider,
+ normalizedPath = None,
+ isStreamingTableOpt = None
+ )
+ )
+
+ // Register flow that backs this materialized view.
+ graphRegistrationContext.registerFlow(
+ UnresolvedFlow(
+ identifier = mvIdentifier,
+ destinationIdentifier = mvIdentifier,
+ func = FlowAnalysis.createFlowFunctionFromLogicalPlan(cmv.query),
+ sqlConf = context.getSqlConf,
+ once = false,
+ queryContext = QueryContext(
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ ),
+ comment = cmv.tableSpec.comment,
+ origin = queryOrigin.copy(
+ objectName = Option(mvIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Flow.toString)
+ )
+ )
+ )
+ }
+ }
+
+ private object CreatePersistedViewCommandHandler {
+ def handle(cv: CreateView, queryOrigin: QueryOrigin): Unit = {
+ val viewIdentifier =
GraphIdentifierManager.parseAndValidatePersistedViewIdentifier(
+ rawViewIdentifier = IdentifierHelper.toTableIdentifier(cv.child),
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ )
+
+ // Register persisted view definition.
+ graphRegistrationContext.registerView(
+ PersistedView(
+ identifier = viewIdentifier,
+ comment = cv.comment,
+ origin = queryOrigin.copy(
+ objectName = Option(viewIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.View.toString)
+ ),
+ properties = cv.properties
+ )
+ )
+
+ // Register flow that backs this persisted view.
+ graphRegistrationContext.registerFlow(
+ UnresolvedFlow(
+ identifier = viewIdentifier,
+ destinationIdentifier = viewIdentifier,
+ func = FlowAnalysis.createFlowFunctionFromLogicalPlan(cv.query),
+ sqlConf = context.getSqlConf,
+ once = false,
+ queryContext = QueryContext(
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ ),
+ origin = queryOrigin.copy(
+ objectName = Option(viewIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Flow.toString)
+ ),
+ comment = None
+ )
+ )
+ }
+ }
+
+ private object CreateTemporaryViewHandler {
+ def handle(cvc: CreateViewCommand, queryOrigin: QueryOrigin): Unit = {
+ // Validate the temporary view is not fully qualified, and then qualify
it with the pipeline
+ // catalog/database.
+ val viewIdentifier = GraphIdentifierManager
+ .parseAndValidateTemporaryViewIdentifier(
+ rawViewIdentifier = cvc.name
+ )
+
+ // Register temporary view definition.
+ graphRegistrationContext.registerView(
+ TemporaryView(
+ identifier = viewIdentifier,
+ comment = cvc.comment,
+ origin = queryOrigin.copy(
+ objectName = Option(viewIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.View.toString)
+ ),
+ properties = Map.empty
+ )
+ )
+
+ // Register flow definition that backs this temporary view.
+ graphRegistrationContext.registerFlow(
+ UnresolvedFlow(
+ identifier = viewIdentifier,
+ destinationIdentifier = viewIdentifier,
+ func = FlowAnalysis.createFlowFunctionFromLogicalPlan(cvc.plan),
+ sqlConf = context.getSqlConf,
+ once = false,
+ queryContext = QueryContext(
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ ),
+ origin = queryOrigin.copy(
+ objectName = Option(viewIdentifier.unquotedString),
+ objectType = Option(QueryOriginType.Flow.toString)
+ ),
+ comment = None
+ )
+ )
+ }
+ }
+
+ private object CreateFlowHandler {
+ def handle(cf: CreateFlowCommand, queryOrigin: QueryOrigin): Unit = {
+ val rawFlowIdentifier =
+ IdentifierHelper.toTableIdentifier(cf.name)
+ if (!IdentifierHelper.isSinglePartIdentifier(
+ rawFlowIdentifier
+ )) {
+ throw new AnalysisException(
+ "MULTIPART_FLOW_NAME_NOT_SUPPORTED",
+ Map("flowName" -> rawFlowIdentifier.unquotedString)
+ )
+ }
+
+ val flowIdentifier = GraphIdentifierManager
+ .parseAndQualifyFlowIdentifier(
+ rawFlowIdentifier = rawFlowIdentifier,
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ )
+ .identifier
+
+ val (flowTargetDatasetIdentifier, flowQueryLogicalPlan, isOnce) =
cf.flowOperation match {
+ case i: InsertIntoStatement =>
+ validateInsertIntoFlow(i, queryOrigin)
+ val flowTargetDatasetName = i.table match {
+ case u: UnresolvedRelation =>
+ IdentifierHelper.toTableIdentifier(u.multipartIdentifier)
+ case _ =>
+ throw SqlGraphElementRegistrationException(
+ msg = "Unable to resolve target dataset name for INSERT INTO
flow",
+ queryOrigin = queryOrigin
+ )
+ }
+ val qualifiedFlowTargetDatasetName = GraphIdentifierManager
+ .parseAndQualifyTableIdentifier(
+ rawTableIdentifier = flowTargetDatasetName,
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ )
+ .identifier
+ (qualifiedFlowTargetDatasetName, i.query, false)
+ case _ =>
+ throw SqlGraphElementRegistrationException(
+ msg = "Unable flow type. Only INSERT INTO flows are supported.",
+ queryOrigin = queryOrigin
+ )
+ }
+
+ val qualifiedDestinationIdentifier = GraphIdentifierManager
+ .parseAndQualifyFlowIdentifier(
+ rawFlowIdentifier = flowTargetDatasetIdentifier,
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ )
+ .identifier
+
+ graphRegistrationContext.registerFlow(
+ UnresolvedFlow(
+ identifier = flowIdentifier,
+ destinationIdentifier = qualifiedDestinationIdentifier,
+ comment = cf.comment,
+ func =
FlowAnalysis.createFlowFunctionFromLogicalPlan(flowQueryLogicalPlan),
+ sqlConf = context.getSqlConf,
+ once = isOnce,
+ queryContext = QueryContext(
+ currentCatalog = context.getCurrentCatalogOpt,
+ currentDatabase = context.getCurrentDatabaseOpt
+ ),
+ origin = queryOrigin
+ )
+ )
+ }
+
+ private def validateInsertIntoFlow(
+ insertIntoStatement: InsertIntoStatement,
+ queryOrigin: QueryOrigin
+ ): Unit = {
+ if (insertIntoStatement.partitionSpec.nonEmpty) {
+ throw SqlGraphElementRegistrationException(
+ msg = "Partition spec may not be specified for flow target.",
+ queryOrigin = queryOrigin
+ )
+ }
+ if (insertIntoStatement.userSpecifiedCols.nonEmpty) {
+ throw SqlGraphElementRegistrationException(
+ msg = "Column schema may not be specified for flow target.",
+ queryOrigin = queryOrigin
+ )
+ }
+ if (insertIntoStatement.overwrite) {
+ throw SqlGraphElementRegistrationException(
+ msg = "INSERT OVERWRITE flows not supported.",
+ queryOrigin = queryOrigin
+ )
+ }
+ if (insertIntoStatement.ifPartitionNotExists) {
+ throw SqlGraphElementRegistrationException(
+ msg = "IF NOT EXISTS not supported for flows.",
+ queryOrigin = queryOrigin
+ )
+ }
+ if (!insertIntoStatement.byName) {
+ throw SqlGraphElementRegistrationException(
+ msg = "Only INSERT INTO by name flows supported.",
+ queryOrigin = queryOrigin
+ )
+ }
+ }
+ }
+
+ private object SetCommandHandler {
+ def handle(setCommand: SetCommand): Unit = {
+ val sqlConfKvPair = setCommand.kv.getOrElse(
+ throw new RuntimeException("Invalid SET command without key-value
pair")
+ )
+ val sqlConfKey = sqlConfKvPair._1
+ val sqlConfValue = sqlConfKvPair._2.getOrElse(
+ throw new RuntimeException("Invalid SET command without value")
+ )
+ context.setSqlConf(sqlConfKey, sqlConfValue)
+ }
+ }
+
+ private object SetNamespaceCommandHandler {
+ def handle(setNamespaceCommand: SetNamespaceCommand): Unit = {
+ setNamespaceCommand.namespace match {
+ case Seq(database) =>
+ context.setCurrentDatabase(database)
+ case Seq(catalog, database) =>
+ context.setCurrentCatalog(catalog)
+ context.setCurrentDatabase(database)
+ case invalidSchemaIdentifier =>
+ throw new SparkException(
+ "Invalid schema identifier provided on USE command: " +
+ s"$invalidSchemaIdentifier"
+ )
+ }
+ }
+ }
+
+ private object SetCatalogCommandHandler {
+ def handle(setCatalogCommand: SetCatalogCommand): Unit = {
+ context.setCurrentCatalog(setCatalogCommand.catalogName)
+ context.clearCurrentDatabase()
+ }
+ }
+}
+
+object PartitionHelper {
+ import org.apache.spark.sql.connector.expressions.{IdentityTransform,
Transform}
+
+ def applyPartitioning(partitioning: Seq[Transform], queryOrigin:
QueryOrigin): Seq[String] = {
+ partitioning.foreach {
+ case _: IdentityTransform =>
+ case other =>
+ throw SqlGraphElementRegistrationException(
+ msg = s"Invalid partitioning transform ($other)",
+ queryOrigin = queryOrigin
+ )
+ }
+ partitioning.collect {
+ case t: IdentityTransform =>
+ if (t.references.length != 1) {
+ throw SqlGraphElementRegistrationException(
+ msg = "Only single column based partitioning is supported.",
+ queryOrigin = queryOrigin
+ )
+ }
+ if (t.ref.fieldNames().length != 1) {
+ throw SqlGraphElementRegistrationException(
+ msg = "Multipart partition identifier not allowed.",
+ queryOrigin = queryOrigin
+ )
+ }
+ t.ref.fieldNames().head
+ }
+ }
+}
+
+object SqlGraphRegistrationContext {
+ /**
+ * Split SQL statements by semicolon.
+ *
+ * Note that an input SQL text/blob like:
+ * "-- comment 1
+ * SELECT 1;
+ *
+ * SELECT 2 ; -- comment 2"
+ *
+ * Will be split into the two following strings:
+ * "-- comment 1
+ * SELECT 1",
+ * "
+ * SELECT 2 "
+ *
+ * The semicolon that terminates a statement is not included in the returned
string for that
+ * statement, any white space/comments surrounding a statement is included
in the returned
+ * string for that statement, and any white space/comments following the
last semicolon
+ * terminated statement is not returned.
+ */
+ private def splitSqlTextBySemicolon(sqlText: String): List[String] =
StringUtils
+ .splitSemiColonWithIndex(line = sqlText, enableSqlScripting = false)
+
+ /** Class that holds the logical plan and query origin parsed from a SQL
statement. */
+ case class SqlQueryPlanWithOrigin(plan: LogicalPlan, queryOrigin:
QueryOrigin)
+
+ /**
+ * Given a SQL file (raw text content and path), return the parsed logical
plan and query origin
+ * per SQL statement in the file contents.
+ *
+ * Note that the returned origins will not be complete - origin information
like object name and
+ * type will only be determined and populated when the logical plan is
inspected during SQL
+ * element registration.
+ *
+ * @param spark the spark session to use to parse SQL statements.
+ * @param sqlFileText the raw text content of the SQL file.
+ * @param sqlFilePath the file path to the SQL file. Only used to populate
the query origin.
+ * @return a [[SqlQueryPlanWithOrigin]] object per SQL statement, in the
same order the SQL
+ * statements were defined in the file contents.
+ */
+ def splitSqlFileIntoQueries(
+ spark: SparkSession,
+ sqlFileText: String,
+ sqlFilePath: String
+ ): Seq[SqlQueryPlanWithOrigin] = {
+ // The index in the file we've processed up to at this point
+ var currentCharIndexInFile = 0
+
+ val rawSqlStatements = splitSqlTextBySemicolon(sqlFileText)
+ rawSqlStatements.map { rawSqlStatement =>
+ val rawSqlStatementText = rawSqlStatement
+ val logicalPlanFromSqlQuery =
spark.sessionState.sqlParser.parsePlan(rawSqlStatementText)
+
+ // Update and return the query origin, accounting for the position of
the statement with
+ // respect to the entire file.
+
+ // The actual start position of the SQL query in this sqlText string.
Within sqlText it's
+ // possible that whitespace or comments precede the start of the query,
and is accounted for
+ // in the parsed logical plan's origin.
+ val sqlStatementStartIdxInString =
logicalPlanFromSqlQuery.origin.startIndex.getOrElse(
+ throw new SparkRuntimeException(
+ errorClass = "INTERNAL_ERROR",
+ messageParameters = Map(
+ "message" ->
+ s"""Unable to retrieve start index of logical plan parsed by the
following
+ |SQL text:
+ |
+ |$rawSqlStatementText""".stripMargin)
+ )
+ )
+
+ // The actual start position of the SQL query in the entire file.
+ val sqlStatementStartIndexInFile = currentCharIndexInFile +
sqlStatementStartIdxInString
+
+ // The line number is the number of newline characters found prior to
the start of this SQL
+ // statement, plus 1 for 1-indexing. Ex. "SELECT 1;" should be on line
1, not line 0.
+ val sqlStatementLineNumber = 1 + sqlFileText.substring(0,
sqlStatementStartIndexInFile)
+ .count(_ == '\n')
+
+ // Move the current char index/ptr by the length of the raw SQL text we
just processed, plus
+ // 1 to account for delimiting semicolon.
+ currentCharIndexInFile += rawSqlStatementText.length + 1
+
+ // Return the updated query origin with line number and start position.
+ SqlQueryPlanWithOrigin(
+ plan = logicalPlanFromSqlQuery,
+ queryOrigin = QueryOrigin(
+ language = Option(Language.Sql()),
+ filePath = Option(sqlFilePath),
+ // Raw SQL text, after stripping away preceding whitespace
+ sqlText =
Option(rawSqlStatementText.substring(sqlStatementStartIdxInString)),
+ line = Option(sqlStatementLineNumber),
+ startPosition = Option(sqlStatementStartIndexInFile)
+ )
+ )
+ }
+ }
+}
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/FlowProgressEventLogger.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/FlowProgressEventLogger.scala
index b96bd64a9ce6..ac239c7156b1 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/FlowProgressEventLogger.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/FlowProgressEventLogger.scala
@@ -62,7 +62,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = EventLevel.INFO,
message = s"Flow ${flow.displayName} is QUEUED.",
@@ -81,7 +81,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(batchFlow.displayName),
datasetName = None,
- sourceCodeLocation = Option(batchFlow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(batchFlow.origin)
),
level = EventLevel.INFO,
message = s"Flow ${batchFlow.displayName} is PLANNING.",
@@ -102,7 +102,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flowExecution.displayName),
datasetName = None,
- sourceCodeLocation =
Option(flowExecution.getOrigin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flowExecution.getOrigin)
),
level = EventLevel.INFO,
message = s"Flow ${flowExecution.displayName} is STARTING.",
@@ -119,7 +119,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = EventLevel.INFO,
message = s"Flow ${flow.displayName} is RUNNING.",
@@ -147,7 +147,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = if (logAsWarn) EventLevel.WARN else EventLevel.ERROR,
message = eventLogMessage,
@@ -170,7 +170,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = EventLevel.WARN,
message = s"Flow '${flow.displayName}' SKIPPED due to upstream
failure(s).",
@@ -193,7 +193,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = EventLevel.INFO,
message = {
@@ -213,7 +213,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = EventLevel.INFO,
message = s"Flow '${flow.displayName}' is EXCLUDED.",
@@ -237,7 +237,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = EventLevel.INFO,
message = message.getOrElse(s"Flow '${flow.displayName}' has
STOPPED."),
@@ -257,7 +257,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = EventLevel.INFO,
message = s"Flow '${flow.displayName}' is IDLE, waiting for new data.",
@@ -282,7 +282,7 @@ class FlowProgressEventLogger(eventBuffer:
PipelineRunEventBuffer) extends Loggi
origin = PipelineEventOrigin(
flowName = Option(flow.displayName),
datasetName = None,
- sourceCodeLocation = Option(flow.origin.toSourceCodeLocation)
+ sourceCodeLocation = Option(flow.origin)
),
level = EventLevel.INFO,
message = s"Flow ${flow.displayName} has COMPLETED.",
diff --git
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineEvent.scala
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineEvent.scala
index 90dcbc6e911f..2f3d782cd853 100644
---
a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineEvent.scala
+++
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineEvent.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.pipelines.logging
import org.apache.spark.sql.pipelines.common.{FlowStatus, RunState}
+import org.apache.spark.sql.pipelines.graph.QueryOrigin
/**
* An internal event that is emitted during the run of a pipeline.
@@ -47,23 +48,7 @@ case class PipelineEvent(
case class PipelineEventOrigin(
datasetName: Option[String],
flowName: Option[String],
- sourceCodeLocation: Option[SourceCodeLocation]
-)
-
-/**
- * Describes the location of the source code
- * @param path The path to the source code
- * @param lineNumber The line number of the source code
- * @param columnNumber The column number of the source code
- * @param endingLineNumber The ending line number of the source code
- * @param endingColumnNumber The ending column number of the source code
- */
-case class SourceCodeLocation(
- path: Option[String],
- lineNumber: Option[Int],
- columnNumber: Option[Int],
- endingLineNumber: Option[Int],
- endingColumnNumber: Option[Int]
+ sourceCodeLocation: Option[QueryOrigin]
)
// Additional details about the PipelineEvent
diff --git
a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlPipelineSuite.scala
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlPipelineSuite.scala
new file mode 100644
index 000000000000..130a024f2bb1
--- /dev/null
+++
b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlPipelineSuite.scala
@@ -0,0 +1,855 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.pipelines.utils.{PipelineTest,
TestGraphRegistrationContext}
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{LongType, StructType}
+import org.apache.spark.util.Utils
+
+class SqlPipelineSuite extends PipelineTest with SQLTestUtils {
+ private val externalTable1Ident = TableIdentifier(
+ table = "external_t1",
+ database = Option(TestGraphRegistrationContext.DEFAULT_DATABASE),
+ catalog = Option(TestGraphRegistrationContext.DEFAULT_CATALOG)
+ )
+ private val externalTable2Ident = TableIdentifier(
+ table = "external_t2",
+ database = Option(TestGraphRegistrationContext.DEFAULT_DATABASE),
+ catalog = Option(TestGraphRegistrationContext.DEFAULT_CATALOG)
+ )
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ // Create mock external tables that tests can reference, ex. to stream
from.
+ spark.sql(s"CREATE TABLE $externalTable1Ident AS SELECT * FROM RANGE(3)")
+ spark.sql(s"CREATE TABLE $externalTable2Ident AS SELECT * FROM RANGE(4)")
+ }
+
+ override def afterEach(): Unit = {
+ spark.sql(s"DROP TABLE IF EXISTS $externalTable1Ident")
+ spark.sql(s"DROP TABLE IF EXISTS $externalTable2Ident")
+ super.afterEach()
+ }
+
+ test("Simple register SQL dataset test") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = s"""
+ |CREATE MATERIALIZED VIEW mv AS SELECT 1;
+      |CREATE STREAMING TABLE st AS SELECT * FROM STREAM $externalTable1Ident;
+ |CREATE VIEW v AS SELECT * FROM mv;
+ |CREATE FLOW f AS INSERT INTO st BY NAME
+ |SELECT * FROM STREAM $externalTable2Ident;
+ |""".stripMargin
+ )
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+
+ assert(resolvedDataflowGraph.flows.size == 4)
+ assert(resolvedDataflowGraph.tables.size == 2)
+ assert(resolvedDataflowGraph.views.size == 1)
+
+ val mvFlow =
+ resolvedDataflowGraph.resolvedFlows
+ .filter(_.identifier == fullyQualifiedIdentifier("mv"))
+ .head
+ assert(mvFlow.inputs.map(_.table) == Set())
+ assert(mvFlow.destinationIdentifier == fullyQualifiedIdentifier("mv"))
+
+ val stFlow =
+ resolvedDataflowGraph.resolvedFlows
+ .filter(_.identifier == fullyQualifiedIdentifier("st"))
+ .head
+    // The streaming table has 1 external input, and no internal (defined within pipeline) inputs
+ assert(stFlow.funcResult.usedExternalInputs == Set(externalTable1Ident))
+ assert(stFlow.inputs.isEmpty)
+ assert(stFlow.destinationIdentifier == fullyQualifiedIdentifier("st"))
+
+ val viewFlow =
+ resolvedDataflowGraph.resolvedFlows
+ .filter(_.identifier == fullyQualifiedIdentifier("v"))
+ .head
+ assert(viewFlow.inputs == Set(fullyQualifiedIdentifier("mv")))
+ assert(viewFlow.destinationIdentifier == fullyQualifiedIdentifier("v"))
+
+ val namedFlow =
+      resolvedDataflowGraph.resolvedFlows.filter(_.identifier == fullyQualifiedIdentifier("f")).head
+ assert(namedFlow.funcResult.usedExternalInputs == Set(externalTable2Ident))
+ assert(namedFlow.inputs.isEmpty)
+ assert(namedFlow.destinationIdentifier == fullyQualifiedIdentifier("st"))
+ }
+
+ test("Duplicate table name across different SQL files fails") {
+ val graphRegistrationContext = new TestGraphRegistrationContext(spark)
+    val sqlGraphRegistrationContext = new SqlGraphRegistrationContext(graphRegistrationContext)
+
+ sqlGraphRegistrationContext.processSqlFile(
+ sqlText = "CREATE STREAMING TABLE table;",
+ sqlFilePath = "a.sql",
+ spark = spark
+ )
+
+ sqlGraphRegistrationContext.processSqlFile(
+ sqlText = """
+ |CREATE VIEW table AS SELECT 1;
+ |""".stripMargin,
+ sqlFilePath = "b.sql",
+ spark = spark
+ )
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ graphRegistrationContext.toDataflowGraph
+ },
+ condition = "PIPELINE_DUPLICATE_IDENTIFIERS.DATASET",
+ sqlState = Option("42710"),
+ parameters = Map(
+ "datasetName" -> fullyQualifiedIdentifier("table").quotedString,
+ "datasetType1" -> "TABLE",
+ "datasetType2" -> "VIEW"
+ )
+ )
+ }
+
+ test("Static pipeline dataset resolves correctly") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText =
+        "CREATE MATERIALIZED VIEW a COMMENT 'this is a comment' AS SELECT * FROM range(1, 4)"
+ )
+
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+
+ val flowA =
+ resolvedDataflowGraph.resolvedFlows
+ .filter(_.identifier == fullyQualifiedIdentifier("a"))
+ .head
+ assert(flowA.comment.contains("this is a comment"))
+ checkAnswer(flowA.df, Seq(Row(1), Row(2), Row(3)))
+ }
+
+ test("Special characters in dataset name allowed when escaped") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = """
+      |CREATE MATERIALIZED VIEW `hyphen-mv` AS SELECT * FROM range(1, 4);
+      |CREATE MATERIALIZED VIEW `other-hyphen-mv` AS SELECT * FROM `hyphen-mv`
+ |""".stripMargin
+ )
+
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+
+ assert(
+ resolvedDataflowGraph.resolvedFlows
+        .exists(f => f.identifier == fullyQualifiedIdentifier("hyphen-mv") && !f.df.isStreaming)
+ )
+
+ assert(
+ resolvedDataflowGraph.resolvedFlows
+ .exists(f =>
+          f.identifier == fullyQualifiedIdentifier("other-hyphen-mv") && !f.df.isStreaming)
+ )
+ }
+
+ test("Pipeline with batch dependencies is correctly resolved") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = """
+ |CREATE VIEW b as SELECT * FROM a;
+ |CREATE VIEW a AS SELECT * FROM range(1, 4);
+ |CREATE VIEW c AS SELECT * FROM `b`;
+ |CREATE VIEW d AS SELECT * FROM c
+ |""".stripMargin
+ )
+
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+
+ Seq("a", "b", "c", "d").foreach { datasetName =>
+ val backingFlow = resolvedDataflowGraph.resolvedFlows
+ .find(_.identifier == fullyQualifiedIdentifier(datasetName))
+ .head
+ checkAnswer(backingFlow.df, Seq(Row(1), Row(2), Row(3)))
+ }
+ }
+
+ test("Pipeline dataset can be referenced in subquery") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = """
+ |CREATE MATERIALIZED VIEW A AS SELECT * FROM RANGE(5);
+ |CREATE MATERIALIZED VIEW B AS SELECT * FROM RANGE(5)
+ |WHERE id = (SELECT max(id) FROM A);
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(
+ spark
+ .sql(s"SELECT * FROM ${fullyQualifiedIdentifier("B").quotedString}"),
+ Row(4)
+ )
+ }
+
+ test("Pipeline datasets can have dependency on streaming table") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = s"""
+      |CREATE STREAMING TABLE a AS SELECT * FROM STREAM($externalTable1Ident);
+ |CREATE MATERIALIZED VIEW b AS SELECT * FROM a;
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(
+ spark
+ .sql(s"SELECT * FROM ${fullyQualifiedIdentifier("b").quotedString}"),
+ Seq(Row(0), Row(1), Row(2))
+ )
+ }
+
+ test("SQL aggregation works") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText =
+ """
+        |CREATE MATERIALIZED VIEW a AS SELECT id AS value, (id % 2) AS isOdd FROM range(1,10);
+ |CREATE MATERIALIZED VIEW b AS SELECT isOdd, max(value) AS
+ |maximum FROM a GROUP BY isOdd LIMIT 2;
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(
+ spark
+ .sql(s"SELECT * FROM ${fullyQualifiedIdentifier("b").quotedString}"),
+ Seq(Row(0, 8), Row(1, 9))
+ )
+ }
+
+ test("SQL join works") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = """
+ |CREATE TEMPORARY VIEW a AS SELECT id FROM range(1,3);
+ |CREATE TEMPORARY VIEW b AS SELECT id FROM range(1,3);
+      |CREATE MATERIALIZED VIEW c AS SELECT a.id AS id1, b.id AS id2
+ |FROM a JOIN b ON a.id=b.id
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(
+ spark
+ .sql(s"SELECT * FROM ${fullyQualifiedIdentifier("c").quotedString}"),
+ Seq(Row(1, 1), Row(2, 2))
+ )
+ }
+
+ test("Partition cols correctly registered") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = """
+ |CREATE MATERIALIZED VIEW a
+ |PARTITIONED BY (id1, id2)
+      |AS SELECT id as id1, id as id2 FROM range(1,2)""".stripMargin
+ )
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+
+ assert(
+ resolvedDataflowGraph.tables
+ .find(_.identifier == fullyQualifiedIdentifier("a"))
+ .head
+ .partitionCols
+ .contains(Seq("id1", "id2"))
+ )
+ }
+
+ test("Exception is thrown when non-identity partition columns are used") {
+ val graphRegistrationContext = new TestGraphRegistrationContext(spark)
+    val sqlGraphRegistrationContext = new SqlGraphRegistrationContext(graphRegistrationContext)
+
+ val ex = intercept[SqlGraphElementRegistrationException] {
+ sqlGraphRegistrationContext.processSqlFile(
+ sqlText = """
+ |CREATE MATERIALIZED VIEW a
+ |PARTITIONED BY (year(id1))
+          |AS SELECT id as id1, id as id2 FROM range(1,2)""".stripMargin,
+ sqlFilePath = "a.sql",
+ spark = spark
+ )
+ }
+
+ assert(ex.getMessage.contains("Invalid partitioning transform"))
+ }
+
+ test("Table properties are correctly registered") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+      sqlText = "CREATE STREAMING TABLE st TBLPROPERTIES ('prop1'='foo', 'prop2'='bar') AS SELECT 1"
+ )
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+ assert(
+ resolvedDataflowGraph.tables
+ .find(_.identifier == fullyQualifiedIdentifier("st"))
+ .head
+ .properties == Map(
+ "prop1" -> "foo",
+ "prop2" -> "bar"
+ )
+ )
+ }
+
+ test("Spark confs are correctly registered") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = """
+ |CREATE MATERIALIZED VIEW a AS SELECT id FROM range(1,2);
+ |SET conf.test = a;
+ |CREATE VIEW b AS SELECT id from range(1,2);
+ |SET conf.test = b;
+ |SET conf.test2 = c;
+ |CREATE STREAMING TABLE c AS SELECT id FROM range(1,2);
+ |""".stripMargin
+ )
+
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+
+ assert(
+ resolvedDataflowGraph.flows
+ .find(_.identifier == fullyQualifiedIdentifier("a"))
+ .head
+ .sqlConf == Map.empty
+ )
+
+ assert(
+ resolvedDataflowGraph.flows
+ .find(_.identifier == fullyQualifiedIdentifier("b"))
+ .head
+ .sqlConf == Map(
+ "conf.test" -> "a"
+ )
+ )
+
+ assert(
+ resolvedDataflowGraph.flows
+ .find(_.identifier == fullyQualifiedIdentifier("c"))
+ .head
+ .sqlConf == Map(
+ "conf.test" -> "b",
+ "conf.test2" -> "c"
+ )
+ )
+ }
+
+ test("Setting dataset location is disallowed") {
+ val graphRegistrationContext = new TestGraphRegistrationContext(spark)
+    val sqlGraphRegistrationContext = new SqlGraphRegistrationContext(graphRegistrationContext)
+
+ val ex = intercept[ParseException] {
+ sqlGraphRegistrationContext.processSqlFile(
+ sqlText = """CREATE STREAMING TABLE a
+ |LOCATION "/path/to/table"
+ |AS SELECT * FROM range(1,2)""".stripMargin,
+ sqlFilePath = "a.sql",
+ spark = spark
+ )
+ }
+
+ assert(ex.getMessage.contains("Specifying location is not supported"))
+ }
+
+ test("Tables living in arbitrary schemas can be read from pipeline") {
+ val database_name = "db_c215a150_c9c1_4c65_bc02_f7d50dea2f5d"
+ val table_name = s"$database_name.tbl"
+ spark.sql(s"CREATE DATABASE $database_name")
+ spark.sql(s"CREATE TABLE $table_name AS SELECT * FROM range(1,4)")
+
+ withDatabase(database_name) {
+ withTable(table_name) {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = s"""
+          |CREATE MATERIALIZED VIEW a AS SELECT * FROM $table_name;
+          |CREATE STREAMING TABLE b AS SELECT * FROM STREAM($table_name);
+ |CREATE TEMPORARY VIEW c AS SELECT * FROM a;
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ Seq("a", "b").foreach { tableName =>
+ checkAnswer(
+ spark
+ .sql(
+              s"SELECT * FROM ${fullyQualifiedIdentifier(tableName).quotedString}"
+ ),
+ Seq(Row(1), Row(2), Row(3))
+ )
+ }
+ }
+ }
+ }
+
+ gridTest(s"Pipeline dataset can read from file based data sources")(
+ Seq("parquet", "orc", "json", "csv")
+ ) { fileFormat =>
+ val tmpDir = Utils.createTempDir().getAbsolutePath
+    spark.sql("SELECT * FROM RANGE(3)").write.format(fileFormat).mode("overwrite").save(tmpDir)
+
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = s"""
+      |CREATE MATERIALIZED VIEW a AS SELECT * FROM $fileFormat.`$tmpDir`;
+      |CREATE STREAMING TABLE b AS SELECT * FROM STREAM($fileFormat.`$tmpDir`)
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ Seq("a", "b").foreach { datasetName =>
+ val datasetFullyQualifiedName =
+ fullyQualifiedIdentifier(datasetName).quotedString
+ spark.sql(s"REFRESH TABLE $datasetFullyQualifiedName")
+ val expectedRows = if (fileFormat == "csv") {
+ // CSV values are read as strings
+ Seq("0", "1", "2")
+ } else {
+ Seq(0, 1, 2)
+ }
+ checkAnswer(
+ spark.sql(s"SELECT * FROM $datasetFullyQualifiedName"),
+ expectedRows.map(Row(_))
+ )
+ }
+ }
+
+ gridTest("Invalid reads produce correct error message")(
+ Seq(
+      ("csv.``", "The location name cannot be empty string, but `` was given."),
+      ("csv.`/non/existing/file`", "Path does not exist: file:/non/existing/file")
+ )
+ ) {
+ case (path, expectedErrorMsg) =>
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = s"""
+ |CREATE MATERIALIZED VIEW a AS SELECT * FROM $path;
+ |""".stripMargin
+ )
+
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+
+ assert(
+ resolvedDataflowGraph.resolutionFailedFlows
+ .find(_.identifier == fullyQualifiedIdentifier("a"))
+ .head
+ .failure
+ .head
+ .getMessage
+ .contains(expectedErrorMsg)
+ )
+ }
+
+ test("Pipeline dataset can be referenced in CTE") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = """
+ |CREATE MATERIALIZED VIEW a AS SELECT 1;
+ |CREATE MATERIALIZED VIEW d AS
+ |WITH c AS (
+ | WITH b AS (
+ | SELECT * FROM a
+ | )
+ | SELECT * FROM b
+ |)
+ |SELECT * FROM c;
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(
+ spark
+ .sql(s"SELECT * FROM ${fullyQualifiedIdentifier("d").quotedString}"),
+ Row(1)
+ )
+ }
+
+  test("Unsupported SQL statements throw an error") {
+ val graphRegistrationContext = new TestGraphRegistrationContext(spark)
+    val sqlGraphRegistrationContext = new SqlGraphRegistrationContext(graphRegistrationContext)
+
+ val ex = intercept[SqlGraphElementRegistrationException] {
+ sqlGraphRegistrationContext.processSqlFile(
+ sqlText = "CREATE TABLE t AS SELECT 1",
+ sqlFilePath = "a.sql",
+ spark = spark
+ )
+ }
+
+ assert(ex.getMessage.contains("Unsupported plan"))
+ }
+
+ test("Table schema is correctly parsed") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = """
+        |CREATE MATERIALIZED VIEW a (id LONG COMMENT 'comment') AS SELECT * FROM RANGE(5);
+ |""".stripMargin
+ )
+
+ val resolvedDataflowGraph = unresolvedDataflowGraph.resolve()
+
+ // Let inferred/declared schema mismatch detection execute
+ resolvedDataflowGraph.validate()
+
+    val expectedSchema = new StructType().add(name = "id", dataType = LongType, nullable = false)
+
+ assert(
+ resolvedDataflowGraph.resolvedFlows
+ .find(_.identifier == fullyQualifiedIdentifier("a"))
+ .head
+ .schema == expectedSchema
+ )
+ }
+
+ test("Multipart table names supported") {
+ val database_name = "db_4159cf91_42c1_44d6_aa8c_9cd8a158230d"
+ val database2_name = "db_a90d194f_9dfd_44bf_b473_26727e76be7a"
+ spark.sql(s"CREATE DATABASE $database_name")
+ spark.sql(s"CREATE DATABASE $database2_name")
+
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = s"""
+ |CREATE MATERIALIZED VIEW $database_name.mv1 AS SELECT 1;
+        |CREATE MATERIALIZED VIEW $database2_name.mv2 AS SELECT * FROM $database_name.mv1
+ |""".stripMargin
+ )
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(spark.sql(s"SELECT * FROM $database_name.mv1"), Row(1))
+ checkAnswer(spark.sql(s"SELECT * FROM $database2_name.mv2"), Row(1))
+ }
+
+ test("Flow cannot be created with multipart identifier") {
+ val graphRegistrationContext = new TestGraphRegistrationContext(spark)
+    val sqlGraphRegistrationContext = new SqlGraphRegistrationContext(graphRegistrationContext)
+
+ val ex = intercept[AnalysisException] {
+ sqlGraphRegistrationContext.processSqlFile(
+ sqlText = s"""
+ |CREATE STREAMING TABLE st;
+ |CREATE FLOW some_database.f AS INSERT INTO st BY NAME
+ |SELECT * FROM STREAM $externalTable1Ident;
+ |""".stripMargin,
+ sqlFilePath = "a.sql",
+ spark = spark
+ )
+ }
+
+    assert(ex.getMessage.contains("Flow with multipart name 'some_database.f' is not supported"))
+ }
+
+ test("Temporary view cannot be created with multipart identifier") {
+ val graphRegistrationContext = new TestGraphRegistrationContext(spark)
+    val sqlGraphRegistrationContext = new SqlGraphRegistrationContext(graphRegistrationContext)
+
+ val ex = intercept[ParseException] {
+ sqlGraphRegistrationContext.processSqlFile(
+ sqlText = """
+ |CREATE TEMPORARY VIEW some_database.tv AS SELECT 1;
+          |CREATE MATERIALIZED VIEW mv AS SELECT * FROM some_database.tv;
+ |""".stripMargin,
+ sqlFilePath = "a.sql",
+ spark = spark
+ )
+ }
+
+ assert(ex.errorClass.contains("TEMP_VIEW_NAME_TOO_MANY_NAME_PARTS"))
+ }
+
+ test("Use database and set catalog works") {
+ val pipelineCatalog = TestGraphRegistrationContext.DEFAULT_CATALOG
+ val pipelineDatabase = TestGraphRegistrationContext.DEFAULT_DATABASE
+ val otherCatalog = "c_bb3e5598_be3c_4250_a3e1_92c2a75bd3ce"
+ val otherDatabase = "db_caa1d504_ceb5_40e5_b444_ada891288f07"
+
+ // otherDatabase2 will be created in the pipeline's catalog
+ val otherDatabase2 = "db_8b1e9b89_99d8_4f5e_91af_da5b9091143c"
+
+ spark.conf.set(
+ key = s"spark.sql.catalog.$otherCatalog",
+ value = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog"
+ )
+ spark.sql(s"CREATE DATABASE $otherCatalog.$otherDatabase")
+ spark.sql(s"CREATE DATABASE $otherDatabase2")
+
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSqlFiles(
+ sqlFiles = Seq(
+ TestSqlFile(
+ sqlText =
+ s"""
+ |-- Create table in default (pipeline) catalog and database
+ | CREATE MATERIALIZED VIEW mv AS SELECT * FROM RANGE(3);
+ |
+ |-- Change database
+ |USE DATABASE $otherDatabase2;
+ |
+ |-- Create mv2 under new database, implicitly
+ |CREATE MATERIALIZED VIEW mv2 AS SELECT * FROM RANGE(1, 5);
+ |
+            |-- Unqualified names in subquery should implicitly use current database.
+            |-- This should work as current database lives under test_db catalog
+            |-- Fully qualified names in subquery should use specified database.
+            |CREATE MATERIALIZED VIEW mv3 AS
+            |WITH mv AS (SELECT * FROM $pipelineCatalog.$pipelineDatabase.mv)
+ |SELECT mv2.id FROM
+ |mv JOIN mv2 ON mv.id=mv2.id;
+ |
+            |-- Change to database that lives in another catalog. Same behavior expected.
+ |SET CATALOG $otherCatalog;
+ |USE DATABASE $otherDatabase;
+ |
+            |-- Create temporary view. Temporary views should always be created in the pipeline
+            |-- catalog and database, regardless of what the active catalog/database are.
+            |CREATE TEMPORARY VIEW tv AS SELECT * FROM $pipelineCatalog.$pipelineDatabase.mv;
+ |
+ |CREATE MATERIALIZED VIEW mv4 AS
+            |WITH mv2 AS (SELECT * FROM $pipelineCatalog.$otherDatabase2.mv2)
+ |SELECT * FROM STREAM(mv2) WHERE mv2.id % 2 == 0;
+ |
+            |-- Use namespace command should also work, setting both catalog and database.
+            |USE NAMESPACE $pipelineCatalog.$otherDatabase2;
+            |-- mv2 was originally created in this same namespace, so implicit qualification
+ |-- should work.
+ |CREATE MATERIALIZED VIEW mv5 AS SELECT * FROM mv2;
+ |
+            |-- Temp views, which don't support name qualification, should always resolve to
+            |-- pipeline catalog and database despite the active catalog/database
+ |CREATE MATERIALIZED VIEW mv6 AS SELECT * FROM tv;
+ |""".stripMargin,
+ sqlFilePath = "file1.sql"
+ ),
+ TestSqlFile(
+ sqlText =
+ s"""
+            |-- The previous file's current catalog/database should not impact other files;
+            |-- the catalog/database should be reset to the pipeline's.
+            |--
+            |-- Should also be able to read dataset created in other file with custom catalog
+            |-- and database.
+            |CREATE MATERIALIZED VIEW mv6 AS SELECT * FROM $pipelineCatalog.$otherDatabase2.mv5;
+ |""".stripMargin,
+ sqlFilePath = "file2.sql"
+ )
+ )
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(
+ spark.sql(s"SELECT * FROM $pipelineCatalog.$otherDatabase2.mv3"),
+ Seq(Row(1), Row(2))
+ )
+
+ checkAnswer(
+ spark.sql(s"SELECT * FROM $otherCatalog.$otherDatabase.mv4"),
+ Seq(Row(2), Row(4))
+ )
+
+ checkAnswer(
+ spark.sql(s"SELECT * FROM $otherDatabase2.mv5"),
+ Seq(Row(1), Row(2), Row(3), Row(4))
+ )
+
+ checkAnswer(
+ spark.sql(s"SELECT * FROM $otherDatabase2.mv6"),
+ Seq(Row(0), Row(1), Row(2))
+ )
+
+ checkAnswer(
+ spark.sql(s"SELECT * FROM $pipelineCatalog.$pipelineDatabase.mv6"),
+ Seq(Row(1), Row(2), Row(3), Row(4))
+ )
+ }
+
+  test("Writing/reading datasets from fully and partially qualified names works") {
+    spark.catalog.setCurrentCatalog(TestGraphRegistrationContext.DEFAULT_CATALOG)
+    spark.catalog.setCurrentDatabase(TestGraphRegistrationContext.DEFAULT_DATABASE)
+
+ val otherCatalog = "c_25fcf574_171a_4058_b17d_2e38622b702b"
+ val otherDatabase = "db_3558273c_9843_4eac_8b9c_cf7a7d144371"
+ val otherDatabase2 = "db_6ae79691_11f8_43a6_9120_ca7c7dee662a"
+
+ spark.conf.set(
+ key = s"spark.sql.catalog.$otherCatalog",
+ value = "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog"
+ )
+ spark.sql(s"CREATE DATABASE $otherCatalog.$otherDatabase")
+ spark.sql(s"CREATE DATABASE $otherDatabase2")
+
+    // Note: we are intentionally not testing using streaming tables, as the InMemoryManagedCatalog
+    // does not support streaming reads/writes, and checkpoint locations cannot be materialized
+ // without catalog-provided hard storage.
+ Seq(
+ ("upstream_mv", "downstream_mv"),
+ ("upstream_mv2", s"$otherDatabase2.downstream_mv2"),
+ ("upstream_mv3", s"$otherCatalog.$otherDatabase.downstream_mv3"),
+ (s"$otherDatabase2.upstream_mv4", "downstream_mv4"),
+ (s"$otherCatalog.$otherDatabase.upstream_mv5", "downstream_mv5"),
+      (s"$otherCatalog.$otherDatabase.upstream_mv6", s"$otherDatabase2.downstream_mv6")
+ ).foreach { case (table1Ident, table2Ident) =>
+      // The pipeline catalog is [[TestGraphRegistrationContext.DEFAULT_CATALOG]] and database is
+ // [[TestGraphRegistrationContext.DEFAULT_DATABASE]].
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText =
+ s"""
+            |CREATE MATERIALIZED VIEW $table1Ident (id BIGINT) AS SELECT * FROM RANGE(10);
+            |CREATE MATERIALIZED VIEW $table2Ident AS SELECT id FROM $table1Ident
+ |WHERE (id%2)=0;
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(
+ spark.sql(s"SELECT * FROM $table2Ident"),
+ Seq(Row(0), Row(2), Row(4), Row(6), Row(8))
+ )
+
+ spark.sql(s"DROP TABLE $table1Ident")
+ spark.sql(s"DROP TABLE $table2Ident")
+ }
+ }
+
+  test("Creating streaming table without subquery works if streaming table is backed by flows") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = s"""
+ |CREATE STREAMING TABLE st;
+ |CREATE FLOW f AS INSERT INTO st BY NAME
+ |SELECT * FROM STREAM $externalTable1Ident;
+ |""".stripMargin
+ )
+
+ startPipelineAndWaitForCompletion(unresolvedDataflowGraph)
+
+ checkAnswer(
+ spark.sql(s"SELECT * FROM ${fullyQualifiedIdentifier("st")}"),
+ Seq(Row(0), Row(1), Row(2))
+ )
+ }
+
+ test("Empty streaming table definition is disallowed") {
+ val unresolvedDataflowGraph = unresolvedDataflowGraphFromSql(
+ sqlText = "CREATE STREAMING TABLE st;"
+ )
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ unresolvedDataflowGraph
+ .resolve()
+ .validate()
+ },
+ condition = "PIPELINE_DATASET_WITHOUT_FLOW",
+ sqlState = Option("0A000"),
+      parameters = Map("identifier" -> fullyQualifiedIdentifier("st").quotedString)
+ )
+ }
+
+ test("Flow identifiers must be single part") {
+ Seq("a.b", "a.b.c").foreach { flowIdentifier =>
+ val ex = intercept[AnalysisException] {
+ unresolvedDataflowGraphFromSql(
+ sqlText =
+ s"""
+ |CREATE STREAMING TABLE st;
+ |CREATE FLOW $flowIdentifier AS INSERT INTO st BY NAME
+ |SELECT * FROM STREAM $externalTable1Ident
+ |""".stripMargin
+ )
+ }
+ checkError(
+ exception = ex,
+ condition = "MULTIPART_FLOW_NAME_NOT_SUPPORTED",
+ parameters = Map("flowName" -> flowIdentifier)
+ )
+ }
+ }
+
+ test("Duplicate standalone flow identifiers throw an exception") {
+ val ex = intercept[AnalysisException] {
+      // even if flows are defined across multiple files, if there's a duplicate flow identifier an
+ // exception should be thrown.
+ unresolvedDataflowGraphFromSqlFiles(
+ sqlFiles = Seq(
+ TestSqlFile(
+ sqlText =
+ s"""
+ |CREATE STREAMING TABLE st;
+ |CREATE FLOW f AS INSERT INTO st BY NAME
+ |SELECT * FROM STREAM $externalTable1Ident
+ |""".stripMargin,
+ sqlFilePath = "file1.sql"
+ ),
+ TestSqlFile(
+ sqlText =
+ s"""
+ |CREATE FLOW f AS INSERT INTO st BY NAME
+ |SELECT * FROM STREAM $externalTable1Ident
+ |""".stripMargin,
+ sqlFilePath = "file2.sql"
+ )
+ )
+ )
+ }
+ checkError(
+ exception = ex,
+ condition = "PIPELINE_DUPLICATE_IDENTIFIERS.FLOW",
+ parameters = Map(
+ "flowName" -> fullyQualifiedIdentifier("f").unquotedString,
+ "datasetNames" -> fullyQualifiedIdentifier("st").quotedString
+ )
+ )
+ }
+
+ test("Duplicate standalone implicit flow identifier throws exception") {
+ val ex = intercept[AnalysisException] {
+      // even if flows are defined across multiple files, if there's a duplicate flow identifier an
+ // exception should be thrown.
+ unresolvedDataflowGraphFromSqlFiles(
+ sqlFiles = Seq(
+ TestSqlFile(
+ sqlText =
+ s"""
+              |CREATE STREAMING TABLE st AS SELECT * FROM STREAM $externalTable1Ident;
+ |CREATE STREAMING TABLE st2;
+ |""".stripMargin,
+ sqlFilePath = "file1.sql"
+ ),
+ TestSqlFile(
+ sqlText =
+ s"""
+ |CREATE FLOW st AS INSERT INTO st2 BY NAME
+ |SELECT * FROM STREAM $externalTable2Ident
+ |""".stripMargin,
+ sqlFilePath = "file2.sql"
+ )
+ )
+ )
+ }
+ checkError(
+ exception = ex,
+ condition = "PIPELINE_DUPLICATE_IDENTIFIERS.FLOW",
+ parameters = Map(
+ "flowName" -> fullyQualifiedIdentifier("st").unquotedString,
+ "datasetNames" -> Seq(fullyQualifiedIdentifier("st").quotedString,
+ fullyQualifiedIdentifier("st2").quotedString).mkString(",")
+ )
+ )
+ }
+}
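For readers following the suite above, the round trip it exercises repeatedly is roughly the following; it mirrors the helpers added to PipelineTest.scala later in this patch and assumes a test-scoped SparkSession named spark (TestGraphRegistrationContext is a test-only class, not a public API):

  // Sketch of the SQL registration flow under test.
  val registrationContext = new TestGraphRegistrationContext(spark)
  new SqlGraphRegistrationContext(registrationContext).processSqlFile(
    sqlText = "CREATE MATERIALIZED VIEW mv AS SELECT 1",
    sqlFilePath = "pipeline.sql",
    spark = spark)
  val unresolvedGraph = registrationContext.toDataflowGraph // elements registered, not yet resolved
  val resolvedGraph = unresolvedGraph.resolve()             // resolves flows, inputs and schemas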
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlQueryOriginSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlQueryOriginSuite.scala
new file mode 100644
index 000000000000..58c71fe961c8
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SqlQueryOriginSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.pipelines.Language
+import org.apache.spark.sql.pipelines.utils.PipelineTest
+
+class SqlQueryOriginSuite extends PipelineTest {
+ test("basic test") {
+ val sqlQueryOrigins = SqlGraphRegistrationContext.splitSqlFileIntoQueries(
+ spark,
+ sqlFilePath = "file.sql",
+ sqlFileText =
+ """-- comment 1
+ |CREATE MATERIALIZED VIEW a.b.c AS SELECT 1;
+ |
+ |USE DATABASE d ; -- comment 2
+ |""".stripMargin
+ ).map(_.queryOrigin)
+ assert(sqlQueryOrigins == Seq(
+ QueryOrigin(
+ sqlText = Option(
+ """CREATE MATERIALIZED VIEW a.b.c AS SELECT 1""".stripMargin),
+ filePath = Option("file.sql"),
+ language = Option(Language.Sql()),
+ startPosition = Option(13),
+ line = Option(2)
+ ),
+ QueryOrigin(
+ sqlText = Option(
+ """USE DATABASE d """.stripMargin),
+ filePath = Option("file.sql"),
+ language = Option(Language.Sql()),
+ startPosition = Option(58),
+ line = Option(4)
+ )
+ ))
+ }
+
+ test("\\n in sql file is not considered a new line") {
+ val sqlQueryOrigins = SqlGraphRegistrationContext.splitSqlFileIntoQueries(
+ spark,
+ sqlFilePath = "file.sql",
+ sqlFileText =
+ """CREATE STREAMING TABLE `a.\n` AS SELECT "\n";
+ |CREATE VIEW my_view AS SELECT * FROM `a.\n`;
+ |""".stripMargin
+ ).map(_.queryOrigin)
+ assert(sqlQueryOrigins == Seq(
+ QueryOrigin(
+ sqlText = Option("CREATE STREAMING TABLE `a.\\n` AS SELECT \"\\n\""),
+ filePath = Option("file.sql"),
+ language = Option(Language.Sql()),
+ startPosition = Option(0),
+ line = Option(1)
+ ),
+ QueryOrigin(
+ sqlText = Option("CREATE VIEW my_view AS SELECT * FROM `a.\\n`"),
+ filePath = Option("file.sql"),
+ language = Option(Language.Sql()),
+ startPosition = Option(46),
+ line = Option(2)
+ )
+ ))
+ }
+
+ test("White space is accounted for in startPosition") {
+ val sqlQueryOrigins = SqlGraphRegistrationContext.splitSqlFileIntoQueries(
+ spark,
+ sqlFilePath = "file.sql",
+ sqlFileText =
+ s"""
+ | ${"\t"}CREATE FLOW f AS INSERT INTO t BY NAME SELECT 1;
+ |""".stripMargin
+ ).map(_.queryOrigin)
+ assert(sqlQueryOrigins == Seq(
+ QueryOrigin(
+ sqlText = Option("CREATE FLOW f AS INSERT INTO t BY NAME SELECT 1"),
+ filePath = Option("file.sql"),
+ language = Option(Language.Sql()),
+ // 1 new line, 5 spaces, 1 tab
+ startPosition = Option(7),
+ line = Option(2)
+ )
+ ))
+ }
+
+  test("Multiline SQL statement line number is the first line of the statement") {
+ val sqlQueryOrigins = SqlGraphRegistrationContext.splitSqlFileIntoQueries(
+ spark,
+ sqlFilePath = "file.sql",
+ sqlFileText =
+ s"""
+ |CREATE
+ |MATERIALIZED VIEW mv
+ |AS
+ |SELECT 1;
+ |""".stripMargin
+ ).map(_.queryOrigin)
+ assert(sqlQueryOrigins == Seq(
+ QueryOrigin(
+ sqlText = Option(
+ s"""CREATE
+ |MATERIALIZED VIEW mv
+ |AS
+ |SELECT 1""".stripMargin
+ ),
+ filePath = Option("file.sql"),
+ language = Option(Language.Sql()),
+ // 1 new line
+ startPosition = Option(1),
+ line = Option(2)
+ )
+ ))
+ }
+
+  test("Preceding comment is omitted from sql text") {
+ val sqlQueryOrigins = SqlGraphRegistrationContext.splitSqlFileIntoQueries(
+ spark,
+ sqlFilePath = "file.sql",
+ sqlFileText =
+ s"""
+ |-- comment
+ |CREATE MATERIALIZED VIEW mv -- another comment
+ |AS SELECT 1;
+ |""".stripMargin
+ ).map(_.queryOrigin)
+ assert(sqlQueryOrigins == Seq(
+ QueryOrigin(
+ sqlText = Option(
+ s"""CREATE MATERIALIZED VIEW mv -- another comment
+ |AS SELECT 1""".stripMargin
+ ),
+ filePath = Option("file.sql"),
+ language = Option(Language.Sql()),
+ // 1 new line, 10 chars for preceding comment, another new line
+ startPosition = Option(12),
+ line = Option(3)
+ )
+ ))
+ }
+
+ test("Semicolon in string literal does not cause statement split") {
+ val sqlQueryOrigins = SqlGraphRegistrationContext.splitSqlFileIntoQueries(
+ spark,
+ sqlFilePath = "file.sql",
+ sqlFileText = "CREATE TEMPORARY VIEW v AS SELECT 'my ; string';"
+ ).map(_.queryOrigin)
+ assert(sqlQueryOrigins == Seq(
+ QueryOrigin(
+ sqlText = Option("CREATE TEMPORARY VIEW v AS SELECT 'my ; string'"),
+ filePath = Option("file.sql"),
+ language = Option(Language.Sql()),
+ startPosition = Option(0),
+ line = Option(1)
+ )
+ ))
+ }
+}
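As the tests above show, splitSqlFileIntoQueries takes the raw file text and yields one entry per top-level statement, each carrying a QueryOrigin with the statement's text, file path, line and character offset. A minimal usage sketch, inspecting only the queryOrigin field exercised above:

  // Sketch: report where each statement in a SQL file starts.
  val origins = SqlGraphRegistrationContext
    .splitSqlFileIntoQueries(
      spark,
      sqlFilePath = "pipeline.sql",
      sqlFileText = "CREATE MATERIALIZED VIEW a AS SELECT 1;\nCREATE VIEW b AS SELECT * FROM a;")
    .map(_.queryOrigin)
  origins.foreach { o =>
    val path = o.filePath.getOrElse("unknown")
    val line = o.line.getOrElse(-1)
    val text = o.sqlText.getOrElse("")
    println(s"$path:$line -> $text")
  }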
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEventSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEventSuite.scala
index 81d5758b19d5..a9837d5e5f1e 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEventSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEventSuite.scala
@@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.pipelines.common.FlowStatus
+import org.apache.spark.sql.pipelines.graph.QueryOrigin
class ConstructPipelineEventSuite extends SparkFunSuite with BeforeAndAfterEach {
@@ -100,12 +101,10 @@ class ConstructPipelineEventSuite extends SparkFunSuite with BeforeAndAfterEach
datasetName = Some("dataset"),
flowName = Some("flow"),
sourceCodeLocation = Some(
- SourceCodeLocation(
- path = Some("path"),
- lineNumber = None,
- columnNumber = None,
- endingLineNumber = None,
- endingColumnNumber = None
+ QueryOrigin(
+ filePath = Some("path"),
+ line = None,
+ startPosition = None
)
)
),
@@ -116,7 +115,7 @@ class ConstructPipelineEventSuite extends SparkFunSuite with BeforeAndAfterEach
)
assert(event.origin.datasetName.contains("dataset"))
assert(event.origin.flowName.contains("flow"))
- assert(event.origin.sourceCodeLocation.get.path.contains("path"))
+ assert(event.origin.sourceCodeLocation.get.filePath.contains("path"))
assert(event.level == EventLevel.INFO)
assert(event.message == "Flow 'b' has failed")
   assert(event.details.asInstanceOf[FlowProgress].status == FlowStatus.FAILED)
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
index 8f0b2689d27b..a54c09d9e251 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.SparkSession.{clearActiveSession, clearDefaultSessio
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession, SQLContext}
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl, SqlGraphRegistrationContext}
import org.apache.spark.sql.pipelines.utils.PipelineTest.{cleanupMetastore, createTempDir}
abstract class PipelineTest
@@ -55,6 +56,17 @@ abstract class PipelineTest
implicit def sqlContext: SQLContext = spark.sqlContext
def sql(text: String): DataFrame = spark.sql(text)
+  protected def startPipelineAndWaitForCompletion(unresolvedDataflowGraph: DataflowGraph): Unit = {
+ val updateContext = new PipelineUpdateContextImpl(
+ unresolvedDataflowGraph, eventCallback = _ => ())
+ updateContext.pipelineExecution.runPipeline()
+ updateContext.pipelineExecution.awaitCompletion()
+ }
+
+ /**
+   * Spark confs for [[originalSpark]]. Spark confs set here will be the default spark confs for
+ * all spark sessions created in tests.
+ */
protected def sparkConf: SparkConf = {
new SparkConf()
.set("spark.sql.shuffle.partitions", "2")
@@ -89,6 +101,36 @@ abstract class PipelineTest
}
}
+ /** Helper class to represent a SQL file by its contents and path. */
+ protected case class TestSqlFile(sqlText: String, sqlFilePath: String)
+
+  /** Construct an unresolved DataflowGraph object from possibly multiple SQL files. */
+ protected def unresolvedDataflowGraphFromSqlFiles(
+ sqlFiles: Seq[TestSqlFile]
+ ): DataflowGraph = {
+ val graphRegistrationContext = new TestGraphRegistrationContext(spark)
+ sqlFiles.foreach { sqlFile =>
+ new SqlGraphRegistrationContext(graphRegistrationContext).processSqlFile(
+ sqlText = sqlFile.sqlText,
+ sqlFilePath = sqlFile.sqlFilePath,
+ spark = spark
+ )
+ }
+ graphRegistrationContext
+ .toDataflowGraph
+ }
+
+  /** Construct an unresolved DataflowGraph object from a single SQL file, given the file contents
+ * and path. */
+ protected def unresolvedDataflowGraphFromSql(
+ sqlText: String,
+ sqlFilePath: String = "dataset.sql"
+ ): DataflowGraph = {
+ unresolvedDataflowGraphFromSqlFiles(
+ Seq(TestSqlFile(sqlText = sqlText, sqlFilePath = sqlFilePath))
+ )
+ }
+
/**
 * This exists temporarily for compatibility with tests that become invalid when multiple
* executors are available.
@@ -154,7 +196,7 @@ abstract class PipelineTest
    namedGridTest(testNamePrefix, testTags: _*)(params.map(a => a.toString -> a).toMap)(testFun)
}
-  override def test(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */ )(
+  override protected def test(testName: String, testTags: Tag*)(testFun: => Any /* Assertion */ )(
      implicit pos: source.Position): Unit = super.test(testName, testTags: _*) {
runWithInstrumentation(testFun)
}
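Taken together, the new helpers let a test go from SQL text to a completed pipeline run in a few lines. A minimal sketch, assuming a suite shaped like SqlPipelineSuite above (PipelineTest with SQLTestUtils for checkAnswer, the usual org.apache.spark.sql.Row import, and the existing fullyQualifiedIdentifier test helper):

  // Sketch of a typical end-to-end check built on the new helpers.
  test("materialized view produces expected rows") {
    val unresolvedGraph = unresolvedDataflowGraphFromSql(
      sqlText = "CREATE MATERIALIZED VIEW mv AS SELECT * FROM RANGE(3)")
    startPipelineAndWaitForCompletion(unresolvedGraph)
    checkAnswer(
      spark.sql(s"SELECT * FROM ${fullyQualifiedIdentifier("mv").quotedString}"),
      Seq(Row(0), Row(1), Row(2)))
  }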