This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new bba6839d8714 [SPARK-50762][SQL] Add Analyzer rule for resolving SQL scalar UDFs
bba6839d8714 is described below
commit bba6839d87144a251464bda410540e9877cbba2b
Author: Allison Wang <[email protected]>
AuthorDate: Tue Jan 14 14:58:58 2025 +0800
[SPARK-50762][SQL] Add Analyzer rule for resolving SQL scalar UDFs
### What changes were proposed in this pull request?
This PR adds a new Analyzer rule `ResolveSQLFunctions` to resolve scalar
SQL UDFs by replacing a `SQLFunctionExpression` with the actual function body.
It currently supports the following operators: Project, Filter, Join, and
Aggregate.
For example:
```
CREATE FUNCTION area(width DOUBLE, height DOUBLE) RETURNS DOUBLE
RETURN width * height;
```
and this query
```
SELECT area(a, b) FROM t;
```
will be resolved as
```
Project [area(width, height) AS area]
+- Project [a, b, CAST(a AS DOUBLE) AS width, CAST(b AS DOUBLE) AS height]
+- Relation [a, b]
```
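For illustration, a minimal end-to-end sketch of the behavior (the table `t` and
its sample row below are assumed for this example and are not part of the patch):
```
CREATE TABLE t(a INT, b INT);
INSERT INTO t VALUES (2, 3);
-- The function body is inlined by the new rule during analysis;
-- with the row (2, 3) the query returns 6.0.
SELECT area(a, b) FROM t;
```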
### Why are the changes needed?
To support SQL UDFs.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New SQL query tests. More tests will be added once table function
resolution is supported.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49414 from allisonwang-db/spark-50762-resolve-scalar-udf.
Authored-by: Allison Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 13 +
.../sql/catalyst/expressions/ExpressionInfo.java | 2 +-
.../spark/sql/catalyst/analysis/Analyzer.scala | 272 ++++++++++
.../sql/catalyst/analysis/CheckAnalysis.scala | 2 +
.../catalyst/analysis/SQLFunctionExpression.scala | 53 +-
.../sql/catalyst/catalog/SessionCatalog.scala | 103 +++-
.../sql/catalyst/catalog/UserDefinedFunction.scala | 21 +
.../optimizer/EliminateSQLFunctionNode.scala | 47 ++
.../spark/sql/catalyst/optimizer/Optimizer.scala | 1 +
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../sql-tests/analyzer-results/sql-udf.sql.out | 575 +++++++++++++++++++++
.../test/resources/sql-tests/inputs/sql-udf.sql | 122 +++++
.../resources/sql-tests/results/sql-udf.sql.out | 484 +++++++++++++++++
.../spark/sql/execution/SQLFunctionSuite.scala | 61 +++
.../sql/expressions/ExpressionInfoSuite.scala | 3 +-
15 files changed, 1753 insertions(+), 7 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 8b266e9d6ac1..5037b5247542 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3126,6 +3126,13 @@
],
"sqlState" : "42K08"
},
+ "INVALID_SQL_FUNCTION_PLAN_STRUCTURE" : {
+ "message" : [
+ "Invalid SQL function plan structure",
+ "<plan>"
+ ],
+ "sqlState" : "XXKD0"
+ },
"INVALID_SQL_SYNTAX" : {
"message" : [
"Invalid SQL syntax:"
@@ -5757,6 +5764,12 @@
],
"sqlState" : "0A000"
},
+ "UNSUPPORTED_SQL_UDF_USAGE" : {
+ "message" : [
+ "Using SQL function <functionName> in <nodeName> is not supported."
+ ],
+ "sqlState" : "0A000"
+ },
"UNSUPPORTED_STREAMING_OPERATOR_WITHOUT_WATERMARK" : {
"message" : [
"<outputMode> output mode not supported for <statefulOperator> on
streaming DataFrames/DataSets without watermark."
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
index 4200619d3c5f..310d18ddb348 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
@@ -51,7 +51,7 @@ public class ExpressionInfo {
"window_funcs", "xml_funcs", "table_funcs", "url_funcs",
"variant_funcs"));
private static final Set<String> validSources =
- new HashSet<>(Arrays.asList("built-in", "hive", "python_udf",
"scala_udf",
+ new HashSet<>(Arrays.asList("built-in", "hive", "python_udf",
"scala_udf", "sql_udf",
"java_udf", "python_udtf", "internal"));
public String getClassName() {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9282e0554a2d..92cfc4119dd0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -374,6 +374,7 @@ class Analyzer(override val catalogManager: CatalogManager)
extends RuleExecutor
BindProcedures ::
ResolveTableSpec ::
ValidateAndStripPipeExpressions ::
+ ResolveSQLFunctions ::
ResolveAliases ::
ResolveSubquery ::
ResolveSubqueryColumnAliases ::
@@ -2364,6 +2365,277 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
}
}
+ /**
+ * This rule resolves SQL function expressions. It pulls out function inputs and places them
+ * in a separate [[Project]] node below the operator and replaces the SQL function with its
+ * actual function body. SQL function expressions in [[Aggregate]] are
handled in a special
+ * way. Non-aggregated SQL functions in the aggregate expressions of an
Aggregate need to be
+ * pulled out into a Project above the Aggregate before replacing the SQL
function expressions
+ * with actual function bodies. For example:
+ *
+ * Before:
+ * Aggregate [c1] [foo(c1), foo(max(c2)), sum(foo(c2)) AS sum]
+ * +- Relation [c1, c2]
+ *
+ * After:
+ * Project [foo(c1), foo(max_c2), sum]
+ * +- Aggregate [c1] [c1, max(c2) AS max_c2, sum(foo(c2)) AS sum]
+ * +- Relation [c1, c2]
+ */
+ object ResolveSQLFunctions extends Rule[LogicalPlan] {
+
+ private def hasSQLFunctionExpression(exprs: Seq[Expression]): Boolean = {
+ exprs.exists(_.find(_.isInstanceOf[SQLFunctionExpression]).nonEmpty)
+ }
+
+ /**
+ * Check if the function input contains aggregate expressions.
+ */
+ private def checkFunctionInput(f: SQLFunctionExpression): Unit = {
+ if (f.inputs.exists(AggregateExpression.containsAggregate)) {
+ // The input of a SQL function should not contain aggregate functions
after
+ // `extractAndRewrite`. If there are aggregate functions, it means
they are
+ // nested in another aggregate function, which is not allowed.
+ // For example: SELECT sum(foo(sum(c1))) FROM t
+ // We have to throw the error here because otherwise the query plan
after
+ // resolving the SQL function will not be valid.
+ throw new AnalysisException(
+ errorClass = "NESTED_AGGREGATE_FUNCTION",
+ messageParameters = Map.empty)
+ }
+ }
+
+ /**
+ * Resolve a SQL function expression as a logical plan and check if it can be analyzed.
+ */
+ private def resolve(f: SQLFunctionExpression): LogicalPlan = {
+ // Validate the SQL function input.
+ checkFunctionInput(f)
+ val plan = v1SessionCatalog.makeSQLFunctionPlan(f.name, f.function,
f.inputs)
+ val resolved = SQLFunctionContext.withSQLFunction {
+ // Resolve the SQL function plan using its context.
+ val conf = new SQLConf()
+ f.function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k,
v) }
+ SQLConf.withExistingConf(conf) {
+ executeSameContext(plan)
+ }
+ }
+ // Fail the analysis eagerly if a SQL function cannot be resolved using
its input.
+ SimpleAnalyzer.checkAnalysis(resolved)
+ resolved
+ }
+
+ /**
+ * Rewrite SQL function expressions into actual resolved function bodies
and extract
+ * function inputs into the given project list.
+ */
+ private def rewriteSQLFunctions[E <: Expression](
+ expression: E,
+ projectList: ArrayBuffer[NamedExpression]): E = {
+ val newExpr = expression match {
+ case f: SQLFunctionExpression if !hasSQLFunctionExpression(f.inputs) &&
+ // Make sure LateralColumnAliasReference in parameters is resolved
and eliminated first.
+ // Otherwise, the projectList can contain the
LateralColumnAliasReference, which will be
+ // pushed down to a Project without the 'referenced' alias by LCA
present, leaving it
+ // unresolved.
+ !f.inputs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE))
=>
+ withPosition(f) {
+ val plan = resolve(f)
+ // Extract the function input project list from the SQL function
plan and
+ // inline the SQL function expression.
+ plan match {
+ case Project(body :: Nil, Project(aliases, _: LocalRelation)) =>
+ projectList ++= aliases
+ SQLScalarFunction(f.function, aliases.map(_.toAttribute), body)
+ case o =>
+ throw new AnalysisException(
+ errorClass = "INVALID_SQL_FUNCTION_PLAN_STRUCTURE",
+ messageParameters = Map("plan" -> o.toString))
+ }
+ }
+ case o => o.mapChildren(rewriteSQLFunctions(_, projectList))
+ }
+ newExpr.asInstanceOf[E]
+ }
+
+ /**
+ * Check if the given expression contains expressions that should be
extracted,
+ * i.e. non-aggregated SQL functions with non-foldable inputs.
+ */
+ private def shouldExtract(e: Expression): Boolean = e match {
+ // Return false if the expression is already an aggregate expression.
+ case _: AggregateExpression => false
+ case _: SQLFunctionExpression => true
+ case _: LeafExpression => false
+ case o => o.children.exists(shouldExtract)
+ }
+
+ /**
+ * Extract aggregate expressions from the given expression and replace
+ * them with attribute references.
+ * Example:
+ * Before: foo(c1) + foo(max(c2)) + max(foo(c2))
+ * After: foo(c1) + foo(max_c2) + max_foo_c2
+ * Extracted expressions: [c1, max(c2) AS max_c2, max(foo(c2)) AS
max_foo_c2]
+ */
+ private def extractAndRewrite[T <: Expression](
+ expression: T,
+ extractedExprs: ArrayBuffer[NamedExpression]): T = {
+ val newExpr = expression match {
+ case e if !shouldExtract(e) =>
+ val exprToAdd: NamedExpression = e match {
+ case o: OuterReference => Alias(o, toPrettySQL(o.e))()
+ case ne: NamedExpression => ne
+ case o => Alias(o, toPrettySQL(o))()
+ }
+ extractedExprs += exprToAdd
+ exprToAdd.toAttribute
+ case f: SQLFunctionExpression =>
+ val newInputs = f.inputs.map(extractAndRewrite(_, extractedExprs))
+ f.copy(inputs = newInputs)
+ case o => o.mapChildren(extractAndRewrite(_, extractedExprs))
+ }
+ newExpr.asInstanceOf[T]
+ }
+
+ /**
+ * Replace all [[SQLFunctionExpression]]s in an expression with attribute
references
+ * from the aliasMap.
+ */
+ private def replaceSQLFunctionWithAttr[T <: Expression](
+ expr: T,
+ aliasMap: mutable.HashMap[Expression, Alias]): T = {
+ expr.transform {
+ case f: SQLFunctionExpression if aliasMap.contains(f.canonicalized) =>
+ aliasMap(f.canonicalized).toAttribute
+ }.asInstanceOf[T]
+ }
+
+ private def rewrite(plan: LogicalPlan): LogicalPlan = plan match {
+ // Return if a sub-tree does not contain SQLFunctionExpression.
+ case p: LogicalPlan if !p.containsPattern(SQL_FUNCTION_EXPRESSION) => p
+
+ case f @ Filter(cond, a: Aggregate)
+ if !f.resolved || AggregateExpression.containsAggregate(cond) ||
+ ResolveGroupingAnalytics.hasGroupingFunction(cond) ||
+ cond.containsPattern(TEMP_RESOLVED_COLUMN) =>
+ // If the filter's condition contains aggregate expressions or
grouping expressions or temp
+ // resolved column, we cannot rewrite both the filter and the
aggregate until they are
+ // resolved by ResolveAggregateFunctions or ResolveGroupingAnalytics,
because rewriting SQL
+ // functions in aggregate can add an additional project on top of the
aggregate
+ // which breaks the pattern matching in those rules.
+ f.copy(child = a.copy(child = rewrite(a.child)))
+
+ case h @ UnresolvedHaving(_, a: Aggregate) =>
+ // Similarly UnresolvedHaving should be resolved by
ResolveAggregateFunctions first
+ // before rewriting aggregate.
+ h.copy(child = a.copy(child = rewrite(a.child)))
+
+ case a: Aggregate if a.resolved &&
hasSQLFunctionExpression(a.expressions) =>
+ val child = rewrite(a.child)
+ // Extract SQL functions in the grouping expressions and place them in
a project list
+ // below the current aggregate. Also update their appearances in the
aggregate expressions.
+ val bottomProjectList = ArrayBuffer.empty[NamedExpression]
+ val aliasMap = mutable.HashMap.empty[Expression, Alias]
+ val newGrouping = a.groupingExpressions.map { expr =>
+ expr.transformDown {
+ case f: SQLFunctionExpression =>
+ val alias = aliasMap.getOrElseUpdate(f.canonicalized, Alias(f,
f.name)())
+ bottomProjectList += alias
+ alias.toAttribute
+ }
+ }
+ val aggregateExpressions = a.aggregateExpressions.map(
+ replaceSQLFunctionWithAttr(_, aliasMap))
+
+ // Rewrite SQL functions in the aggregate expressions that are not
wrapped in
+ // aggregate functions. They need to be extracted into a project list
above the
+ // current aggregate.
+ val aggExprs = ArrayBuffer.empty[NamedExpression]
+ val topProjectList = aggregateExpressions.map(extractAndRewrite(_,
aggExprs))
+
+ // Rewrite SQL functions in the new aggregate expressions that are
wrapped inside
+ // aggregate functions.
+ val newAggExprs = aggExprs.map(rewriteSQLFunctions(_,
bottomProjectList))
+
+ val bottomProject = if (bottomProjectList.nonEmpty) {
+ Project(child.output ++ bottomProjectList, child)
+ } else {
+ child
+ }
+ val newAgg = if (newGrouping.nonEmpty || newAggExprs.nonEmpty) {
+ a.copy(
+ groupingExpressions = newGrouping,
+ aggregateExpressions = newAggExprs.toSeq,
+ child = bottomProject)
+ } else {
+ bottomProject
+ }
+ if (topProjectList.nonEmpty) Project(topProjectList, newAgg) else
newAgg
+
+ case p: Project if p.resolved && hasSQLFunctionExpression(p.expressions)
=>
+ val newChild = rewrite(p.child)
+ val projectList = ArrayBuffer.empty[NamedExpression]
+ val newPList = p.projectList.map(rewriteSQLFunctions(_, projectList))
+ if (newPList != newChild.output) {
+ p.copy(newPList, Project(newChild.output ++ projectList, newChild))
+ } else {
+ assert(projectList.isEmpty)
+ p.copy(child = newChild)
+ }
+
+ case f: Filter if f.resolved && hasSQLFunctionExpression(f.expressions)
=>
+ val newChild = rewrite(f.child)
+ val projectList = ArrayBuffer.empty[NamedExpression]
+ val newCond = rewriteSQLFunctions(f.condition, projectList)
+ if (newCond != f.condition) {
+ Project(f.output, Filter(newCond, Project(newChild.output ++
projectList, newChild)))
+ } else {
+ assert(projectList.isEmpty)
+ f.copy(child = newChild)
+ }
+
+ case j: Join if j.resolved && hasSQLFunctionExpression(j.expressions) =>
+ val newLeft = rewrite(j.left)
+ val newRight = rewrite(j.right)
+ val projectList = ArrayBuffer.empty[NamedExpression]
+ val joinCond = j.condition.map(rewriteSQLFunctions(_, projectList))
+ if (joinCond != j.condition) {
+ // Join condition cannot have non-deterministic expressions. We can
safely
+ // replace the aliases with the original SQL function input
expressions.
+ val aliasMap = projectList.collect { case a: Alias => a.toAttribute
-> a.child }.toMap
+ val newJoinCond = joinCond.map(_.transform {
+ case a: Attribute => aliasMap.getOrElse(a, a)
+ })
+ j.copy(left = newLeft, right = newRight, condition = newJoinCond)
+ } else {
+ assert(projectList.isEmpty)
+ j.copy(left = newLeft, right = newRight)
+ }
+
+ case o: LogicalPlan if o.resolved &&
hasSQLFunctionExpression(o.expressions) =>
+
o.transformExpressionsWithPruning(_.containsPattern(SQL_FUNCTION_EXPRESSION)) {
+ case f: SQLFunctionExpression =>
+ f.failAnalysis(
+ errorClass = "UNSUPPORTED_SQL_UDF_USAGE",
+ messageParameters = Map(
+ "functionName" -> toSQLId(f.function.name.nameParts),
+ "nodeName" -> o.nodeName.toString))
+ }
+
+ case p: LogicalPlan => p.mapChildren(rewrite)
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ // Only rewrite SQL functions when they are not in nested function calls.
+ if (SQLFunctionContext.get.nestedSQLFunctionDepth > 0) {
+ plan
+ } else {
+ rewrite(plan)
+ }
+ }
+ }
+
/**
* Turns projections that contain aggregate expressions into aggregations.
*/
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 46ca8e793218..0a68524c3124 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -1106,6 +1106,8 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
@scala.annotation.tailrec
def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child)
+ // Skip SQL function node added by the Analyzer
+ case s: SQLFunctionNode => cleanQueryInScalarSubquery(s.child)
case p: Project => cleanQueryInScalarSubquery(p.child)
case h: ResolvedHint => cleanQueryInScalarSubquery(h.child)
case child => child
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
index fb6935d64d4c..37981f47287d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.catalog.SQLFunction
-import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
-import
org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION,
TreePattern}
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression,
Unevaluable}
+import
org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION,
SQL_SCALAR_FUNCTION, TreePattern}
import org.apache.spark.sql.types.DataType
/**
@@ -39,3 +39,52 @@ case class SQLFunctionExpression(
newChildren: IndexedSeq[Expression]): SQLFunctionExpression = copy(inputs
= newChildren)
final override val nodePatterns: Seq[TreePattern] =
Seq(SQL_FUNCTION_EXPRESSION)
}
+
+/**
+ * A wrapper node for a SQL scalar function expression.
+ */
+case class SQLScalarFunction(function: SQLFunction, inputs: Seq[Expression],
child: Expression)
+ extends UnaryExpression with Unevaluable {
+ override def dataType: DataType = child.dataType
+ override def toString: String = s"${function.name}(${inputs.mkString(", ")})"
+ override def sql: String =
s"${function.name}(${inputs.map(_.sql).mkString(", ")})"
+ override protected def withNewChildInternal(newChild: Expression):
SQLScalarFunction = {
+ copy(child = newChild)
+ }
+ final override val nodePatterns: Seq[TreePattern] = Seq(SQL_SCALAR_FUNCTION)
+ // The `inputs` are for display only and do not matter in execution.
+ override lazy val canonicalized: Expression = copy(inputs = Nil, child =
child.canonicalized)
+ override lazy val deterministic: Boolean = {
+ function.deterministic.getOrElse(true) && children.forall(_.deterministic)
+ }
+}
+
+/**
+ * Provide a way to keep state during analysis for resolving nested SQL
functions.
+ *
+ * @param nestedSQLFunctionDepth The nested depth in the SQL function
resolution. A SQL function
+ * expression should only be expanded as a
[[SQLScalarFunction]] if
+ * the nested depth is 0.
+ */
+case class SQLFunctionContext(nestedSQLFunctionDepth: Int = 0)
+
+object SQLFunctionContext {
+
+ private val value = new ThreadLocal[SQLFunctionContext]() {
+ override def initialValue: SQLFunctionContext = SQLFunctionContext()
+ }
+
+ def get: SQLFunctionContext = value.get()
+
+ def reset(): Unit = value.remove()
+
+ private def set(context: SQLFunctionContext): Unit = value.set(context)
+
+ def withSQLFunction[A](f: => A): A = {
+ val originContext = value.get()
+ val context = originContext.copy(
+ nestedSQLFunctionDepth = originContext.nestedSQLFunctionDepth + 1)
+ set(context)
+ try f finally { set(originContext) }
+ }
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 3c6dfe5ac844..b123952c5f08 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -38,9 +38,9 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import
org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
import org.apache.spark.sql.catalyst.catalog.SQLFunction.parseDefault
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference,
Cast, Expression, ExpressionInfo, NamedArgumentExpression, NamedExpression,
UpCast}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference,
Cast, Expression, ExpressionInfo, NamedArgumentExpression, NamedExpression,
ScalarSubquery, UpCast}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser,
ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature,
InputParameter, LogicalPlan, NamedParametersSupport, Project, SubqueryAlias,
View}
+import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature,
InputParameter, LocalRelation, LogicalPlan, NamedParametersSupport, Project,
SubqueryAlias, View}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils}
import org.apache.spark.sql.connector.catalog.CatalogManager
@@ -48,7 +48,7 @@ import
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, PartitioningUtils}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils
@@ -1561,6 +1561,103 @@ class SessionCatalog(
}
}
+ /**
+ * Constructs a scalar SQL function logical plan. The logical plan will be used to
+ * construct the actual expression from the function inputs and body.
+ *
+ * The body of a scalar SQL function can either be an expression or a query that
+ * returns a single column.
+ *
+ * Example scalar SQL function with an expression:
+ *
+ * CREATE FUNCTION area(width DOUBLE, height DOUBLE) RETURNS DOUBLE
+ * RETURN width * height;
+ *
+ * Query:
+ *
+ * SELECT area(a, b) FROM t;
+ *
+ * SQL function plan:
+ *
+ * Project [CAST(width * height AS DOUBLE) AS area]
+ * +- Project [CAST(a AS DOUBLE) AS width, CAST(b AS DOUBLE) AS height]
+ * +- LocalRelation [a, b]
+ *
+ * Example scalar SQL function with a subquery:
+ *
+ * CREATE FUNCTION foo(x INT) RETURNS INT
+ * RETURN SELECT SUM(b) FROM t WHERE x = a;
+ *
+ * SELECT foo(a) FROM t;
+ *
+ * SQL function plan:
+ *
+ * Project [scalar-subquery AS foo]
+ * : +- Aggregate [] [sum(b)]
+ * : +- Filter [outer(x) = a]
+ * : +- Relation [a, b]
+ * +- Project [CAST(a AS INT) AS x]
+ * +- LocalRelation [a, b]
+ */
+ def makeSQLFunctionPlan(
+ name: String,
+ function: SQLFunction,
+ input: Seq[Expression]): LogicalPlan = {
+ def metaForFuncInputAlias = {
+ new MetadataBuilder()
+ .putString("__funcInputAlias", "true")
+ .build()
+ }
+ assert(!function.isTableFunc)
+ val funcName = function.name.funcName
+
+ // Use captured SQL configs when parsing a SQL function.
+ val conf = new SQLConf()
+ function.getSQLConfigs.foreach { case (k, v) => conf.settings.put(k, v) }
+ SQLConf.withExistingConf(conf) {
+ val inputParam = function.inputParam
+ val returnType = function.getScalarFuncReturnType
+ val (expression, query) = function.getExpressionAndQuery(parser,
isTableFunc = false)
+ assert(expression.isDefined || query.isDefined)
+
+ // Check function arguments
+ val paramSize = inputParam.map(_.size).getOrElse(0)
+ if (input.size > paramSize) {
+ throw QueryCompilationErrors.wrongNumArgsError(
+ name, paramSize.toString, input.size)
+ }
+
+ val inputs = inputParam.map { param =>
+ // Attributes referencing the input parameters inside the function can
use the
+ // function name as a qualifier, e.g.:
+ // `create function foo(a int) returns int return foo.a`
+ val qualifier = Seq(funcName)
+ val paddedInput = input ++
+ param.takeRight(paramSize - input.size).map { p =>
+ val defaultExpr = p.getDefault()
+ if (defaultExpr.isDefined) {
+ Cast(parseDefault(defaultExpr.get, parser), p.dataType)
+ } else {
+ throw QueryCompilationErrors.wrongNumArgsError(
+ name, paramSize.toString, input.size)
+ }
+ }
+
+ paddedInput.zip(param.fields).map {
+ case (expr, param) =>
+ Alias(Cast(expr, param.dataType), param.name)(
+ qualifier = qualifier,
+ // mark the alias as function input
+ explicitMetadata = Some(metaForFuncInputAlias))
+ }
+ }.getOrElse(Nil)
+
+ val body = if (query.isDefined) ScalarSubquery(query.get) else
expression.get
+ Project(Alias(Cast(body, returnType), funcName)() :: Nil,
+ Project(inputs, LocalRelation(inputs.flatMap(_.references))))
+ }
+ }
+
/**
* Constructs a [[TableFunctionBuilder]] based on the provided class that
represents a function.
*/
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
index b00cae22cf9c..a76ca7b15c27 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
@@ -45,6 +45,14 @@ trait UserDefinedFunction {
*/
def properties: Map[String, String]
+ /**
+ * Get SQL configs from the function properties.
+ * Use this to restore the SQL configs that should be used for this function.
+ */
+ def getSQLConfigs: Map[String, String] = {
+ UserDefinedFunction.propertiesToSQLConfigs(properties)
+ }
+
/**
* Owner of the function
*/
@@ -142,4 +150,17 @@ object UserDefinedFunction {
* Verify if the function is a [[UserDefinedFunction]].
*/
def isUserDefinedFunction(className: String): Boolean =
SQLFunction.isSQLFunction(className)
+
+ /**
+ * Convert properties to SQL configs.
+ */
+ def propertiesToSQLConfigs(properties: Map[String, String]): Map[String,
String] = {
+ try {
+ for ((key, value) <- properties if key.startsWith(SQL_CONFIG_PREFIX))
+ yield (key.substring(SQL_CONFIG_PREFIX.length), value)
+ } catch {
+ case e: Exception => throw SparkException.internalError(
+ "Corrupted user defined function SQL configs in catalog", cause = e)
+ }
+ }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSQLFunctionNode.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSQLFunctionNode.scala
new file mode 100644
index 000000000000..d9da38b4c2af
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSQLFunctionNode.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.analysis.{SQLFunctionExpression,
SQLFunctionNode, SQLScalarFunction, SQLTableFunction}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * This rule removes the [[SQLScalarFunction]] and [[SQLFunctionNode]] wrappers. They are kept
+ * until the end of the analysis stage because we want to see which part of an analyzed logical
+ * plan is generated from a SQL function and also to perform ACL checks.
+ */
+object EliminateSQLFunctionNode extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ // Include subqueries when eliminating SQL function expressions otherwise
we might miss
+ // expressions in subqueries which can be inlined by the rule
`OptimizeOneRowRelationSubquery`.
+ plan.transformWithSubqueries {
+ case SQLFunctionNode(_, child) => child
+ case f: SQLTableFunction =>
+ throw SparkException.internalError(
+ s"SQL table function plan should be rewritten during analysis: $f")
+ case p: LogicalPlan => p.transformExpressions {
+ case f: SQLScalarFunction => f.child
+ case f: SQLFunctionExpression =>
+ throw SparkException.internalError(
+ s"SQL function expression should be rewritten during analysis: $f")
+ }
+ }
+ }
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8ee2226947ec..9d269f37e58b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -315,6 +315,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateSubqueryAliases,
EliminatePipeOperators,
EliminateView,
+ EliminateSQLFunctionNode,
ReplaceExpressions,
RewriteNonCorrelatedExists,
PullOutGroupingExpressions,
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index b56085ecae8d..9856a26346f6 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -93,6 +93,7 @@ object TreePattern extends Enumeration {
val SESSION_WINDOW: Value = Value
val SORT: Value = Value
val SQL_FUNCTION_EXPRESSION: Value = Value
+ val SQL_SCALAR_FUNCTION: Value = Value
val SQL_TABLE_FUNCTION: Value = Value
val SUBQUERY_ALIAS: Value = Value
val SUM: Value = Value
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-udf.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-udf.sql.out
new file mode 100644
index 000000000000..b3c10e929f29
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-udf.sql.out
@@ -0,0 +1,575 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+CREATE FUNCTION foo1a0() RETURNS INT RETURN 1
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo1a0`"
+ }
+}
+
+
+-- !query
+SELECT foo1a0()
+-- !query analysis
+Project [spark_catalog.default.foo1a0() AS spark_catalog.default.foo1a0()#x]
++- Project
+ +- OneRowRelation
+
+
+-- !query
+SELECT foo1a0(1)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+ "sqlState" : "42605",
+ "messageParameters" : {
+ "actualNum" : "1",
+ "docroot" : "https://spark.apache.org/docs/latest",
+ "expectedNum" : "0",
+ "functionName" : "`spark_catalog`.`default`.`foo1a0`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 16,
+ "fragment" : "foo1a0(1)"
+ } ]
+}
+
+
+-- !query
+CREATE FUNCTION foo1a1(a INT) RETURNS INT RETURN 1
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo1a1`"
+ }
+}
+
+
+-- !query
+SELECT foo1a1(1)
+-- !query analysis
+Project [spark_catalog.default.foo1a1(a#x) AS
spark_catalog.default.foo1a1(1)#x]
++- Project [cast(1 as int) AS a#x]
+ +- OneRowRelation
+
+
+-- !query
+SELECT foo1a1(1, 2)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+ "sqlState" : "42605",
+ "messageParameters" : {
+ "actualNum" : "2",
+ "docroot" : "https://spark.apache.org/docs/latest",
+ "expectedNum" : "1",
+ "functionName" : "`spark_catalog`.`default`.`foo1a1`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 19,
+ "fragment" : "foo1a1(1, 2)"
+ } ]
+}
+
+
+-- !query
+CREATE FUNCTION foo1a2(a INT, b INT, c INT, d INT) RETURNS INT RETURN 1
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo1a2`"
+ }
+}
+
+
+-- !query
+SELECT foo1a2(1, 2, 3, 4)
+-- !query analysis
+Project [spark_catalog.default.foo1a2(a#x, b#x, c#x, d#x) AS
spark_catalog.default.foo1a2(1, 2, 3, 4)#x]
++- Project [cast(1 as int) AS a#x, cast(2 as int) AS b#x, cast(3 as int) AS
c#x, cast(4 as int) AS d#x]
+ +- OneRowRelation
+
+
+-- !query
+CREATE FUNCTION foo2_1a(a INT) RETURNS INT RETURN a
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_1a`"
+ }
+}
+
+
+-- !query
+SELECT foo2_1a(5)
+-- !query analysis
+Project [spark_catalog.default.foo2_1a(a#x) AS
spark_catalog.default.foo2_1a(5)#x]
++- Project [cast(5 as int) AS a#x]
+ +- OneRowRelation
+
+
+-- !query
+CREATE FUNCTION foo2_1b(a INT, b INT) RETURNS INT RETURN a + b
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_1b`"
+ }
+}
+
+
+-- !query
+SELECT foo2_1b(5, 6)
+-- !query analysis
+Project [spark_catalog.default.foo2_1b(a#x, b#x) AS
spark_catalog.default.foo2_1b(5, 6)#x]
++- Project [cast(5 as int) AS a#x, cast(6 as int) AS b#x]
+ +- OneRowRelation
+
+
+-- !query
+CREATE FUNCTION foo2_1c(a INT, b INT) RETURNS INT RETURN 10 * (a + b) + 100 *
(a -b)
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_1c`"
+ }
+}
+
+
+-- !query
+SELECT foo2_1c(5, 6)
+-- !query analysis
+Project [spark_catalog.default.foo2_1c(a#x, b#x) AS
spark_catalog.default.foo2_1c(5, 6)#x]
++- Project [cast(5 as int) AS a#x, cast(6 as int) AS b#x]
+ +- OneRowRelation
+
+
+-- !query
+CREATE FUNCTION foo2_1d(a INT, b INT) RETURNS INT RETURN ABS(a) -
LENGTH(CAST(b AS VARCHAR(10)))
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_1d`"
+ }
+}
+
+
+-- !query
+SELECT foo2_1d(-5, 6)
+-- !query analysis
+Project [spark_catalog.default.foo2_1d(a#x, b#x) AS
spark_catalog.default.foo2_1d(-5, 6)#x]
++- Project [cast(-5 as int) AS a#x, cast(6 as int) AS b#x]
+ +- OneRowRelation
+
+
+-- !query
+CREATE FUNCTION foo2_2a(a INT) RETURNS INT RETURN SELECT a
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_2a`"
+ }
+}
+
+
+-- !query
+SELECT foo2_2a(5)
+-- !query analysis
+Project [spark_catalog.default.foo2_2a(a#x) AS
spark_catalog.default.foo2_2a(5)#x]
++- Project [cast(5 as int) AS a#x]
+ +- OneRowRelation
+
+
+-- !query
+CREATE FUNCTION foo2_2b(a INT) RETURNS INT RETURN 1 + (SELECT a)
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_2b`"
+ }
+}
+
+
+-- !query
+SELECT foo2_2b(5)
+-- !query analysis
+Project [spark_catalog.default.foo2_2b(a#x) AS
spark_catalog.default.foo2_2b(5)#x]
+: +- Project [outer(a#x)]
+: +- OneRowRelation
++- Project [cast(5 as int) AS a#x]
+ +- OneRowRelation
+
+
+-- !query
+CREATE FUNCTION foo2_2c(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT a))
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION",
+ "sqlState" : "42703",
+ "messageParameters" : {
+ "objectName" : "`a`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 21,
+ "stopIndex" : 21,
+ "fragment" : "a"
+ } ]
+}
+
+
+-- !query
+CREATE FUNCTION foo2_2d(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT (SELECT
(SELECT a))))
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION",
+ "sqlState" : "42703",
+ "messageParameters" : {
+ "objectName" : "`a`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 37,
+ "stopIndex" : 37,
+ "fragment" : "a"
+ } ]
+}
+
+
+-- !query
+CREATE FUNCTION foo2_2e(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1) WHERE c1 = 2
+UNION ALL
+SELECT a + 1 FROM (VALUES 1) AS V(c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_2e`"
+ }
+}
+
+
+-- !query
+CREATE FUNCTION foo2_2f(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1)
+EXCEPT
+SELECT a + 1 FROM (VALUES 1) AS V(a)
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_2f`"
+ }
+}
+
+
+-- !query
+CREATE FUNCTION foo2_2g(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1)
+INTERSECT
+SELECT a FROM (VALUES 1) AS V(a)
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_2g`"
+ }
+}
+
+
+-- !query
+DROP TABLE IF EXISTS t1
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t1
+
+
+-- !query
+DROP TABLE IF EXISTS t2
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2
+
+
+-- !query
+DROP TABLE IF EXISTS ts
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.ts
+
+
+-- !query
+DROP TABLE IF EXISTS tm
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.tm
+
+
+-- !query
+DROP TABLE IF EXISTS ta
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.ta
+
+
+-- !query
+DROP TABLE IF EXISTS V1
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.V1
+
+
+-- !query
+DROP TABLE IF EXISTS V2
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.V2
+
+
+-- !query
+DROP VIEW IF EXISTS t1
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`t1`, true, true, false
+
+
+-- !query
+DROP VIEW IF EXISTS t2
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`t2`, true, true, false
+
+
+-- !query
+DROP VIEW IF EXISTS ts
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`ts`, true, true, false
+
+
+-- !query
+DROP VIEW IF EXISTS tm
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`tm`, true, true, false
+
+
+-- !query
+DROP VIEW IF EXISTS ta
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`ta`, true, true, false
+
+
+-- !query
+DROP VIEW IF EXISTS V1
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`V1`, true, true, false
+
+
+-- !query
+DROP VIEW IF EXISTS V2
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`V2`, true, true, false
+
+
+-- !query
+CREATE FUNCTION foo2_3(a INT, b INT) RETURNS INT RETURN a + b
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_3`"
+ }
+}
+
+
+-- !query
+CREATE VIEW V1(c1, c2) AS VALUES (1, 2), (3, 4), (5, 6)
+-- !query analysis
+CreateViewCommand `spark_catalog`.`default`.`V1`, [(c1,None), (c2,None)],
VALUES (1, 2), (3, 4), (5, 6), false, false, PersistedView, COMPENSATION, true
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+CREATE VIEW V2(c1, c2) AS VALUES (-1, -2), (-3, -4), (-5, -6)
+-- !query analysis
+CreateViewCommand `spark_catalog`.`default`.`V2`, [(c1,None), (c2,None)],
VALUES (-1, -2), (-3, -4), (-5, -6), false, false, PersistedView, COMPENSATION,
true
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT foo2_3(c1, c2), foo2_3(c2, 1), foo2_3(c1, c2) - foo2_3(c2, c1 - 1) FROM
V1 ORDER BY 1, 2, 3
+-- !query analysis
+Sort [spark_catalog.default.foo2_3(c1, c2)#x ASC NULLS FIRST,
spark_catalog.default.foo2_3(c2, 1)#x ASC NULLS FIRST,
(spark_catalog.default.foo2_3(c1, c2) - spark_catalog.default.foo2_3(c2, (c1 -
1)))#x ASC NULLS FIRST], true
++- Project [spark_catalog.default.foo2_3(a#x, b#x) AS
spark_catalog.default.foo2_3(c1, c2)#x, spark_catalog.default.foo2_3(a#x, b#x)
AS spark_catalog.default.foo2_3(c2, 1)#x, (spark_catalog.default.foo2_3(a#x,
b#x) - spark_catalog.default.foo2_3(a#x, b#x)) AS
(spark_catalog.default.foo2_3(c1, c2) - spark_catalog.default.foo2_3(c2, (c1 -
1)))#x]
+ +- Project [c1#x, c2#x, cast(c1#x as int) AS a#x, cast(c2#x as int) AS b#x,
cast(c2#x as int) AS a#x, cast(1 as int) AS b#x, cast(c1#x as int) AS a#x,
cast(c2#x as int) AS b#x, cast(c2#x as int) AS a#x, cast((c1#x - 1) as int) AS
b#x]
+ +- SubqueryAlias spark_catalog.default.v1
+ +- View (`spark_catalog`.`default`.`v1`, [c1#x, c2#x])
+ +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS
c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT * FROM V1 WHERE foo2_3(c1, 0) = c1 AND foo2_3(c1, c2) < 8
+-- !query analysis
+Project [c1#x, c2#x]
++- Project [c1#x, c2#x]
+ +- Filter ((spark_catalog.default.foo2_3(a#x, b#x) = c1#x) AND
(spark_catalog.default.foo2_3(a#x, b#x) < 8))
+ +- Project [c1#x, c2#x, cast(c1#x as int) AS a#x, cast(0 as int) AS b#x,
cast(c1#x as int) AS a#x, cast(c2#x as int) AS b#x]
+ +- SubqueryAlias spark_catalog.default.v1
+ +- View (`spark_catalog`.`default`.`v1`, [c1#x, c2#x])
+ +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS
c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT foo2_3(SUM(c1), SUM(c2)), SUM(c1) + SUM(c2), SUM(foo2_3(c1, c2) +
foo2_3(c2, c1) - foo2_3(c2, c1))
+FROM V1
+-- !query analysis
+Project [spark_catalog.default.foo2_3(a#x, b#x) AS
spark_catalog.default.foo2_3(sum(c1), sum(c2))#x, (sum(c1) + sum(c2))#xL,
sum(((spark_catalog.default.foo2_3(c1, c2) + spark_catalog.default.foo2_3(c2,
c1)) - spark_catalog.default.foo2_3(c2, c1)))#xL]
++- Project [sum(c1)#xL, sum(c2)#xL, (sum(c1) + sum(c2))#xL,
sum(((spark_catalog.default.foo2_3(c1, c2) + spark_catalog.default.foo2_3(c2,
c1)) - spark_catalog.default.foo2_3(c2, c1)))#xL, cast(sum(c1)#xL as int) AS
a#x, cast(sum(c2)#xL as int) AS b#x]
+ +- Aggregate [sum(c1#x) AS sum(c1)#xL, sum(c2#x) AS sum(c2)#xL, (sum(c1#x)
+ sum(c2#x)) AS (sum(c1) + sum(c2))#xL, sum(((spark_catalog.default.foo2_3(a#x,
b#x) + spark_catalog.default.foo2_3(a#x, b#x)) -
spark_catalog.default.foo2_3(a#x, b#x))) AS
sum(((spark_catalog.default.foo2_3(c1, c2) + spark_catalog.default.foo2_3(c2,
c1)) - spark_catalog.default.foo2_3(c2, c1)))#xL]
+ +- Project [c1#x, c2#x, cast(c1#x as int) AS a#x, cast(c2#x as int) AS
b#x, cast(c2#x as int) AS a#x, cast(c1#x as int) AS b#x, cast(c2#x as int) AS
a#x, cast(c1#x as int) AS b#x]
+ +- SubqueryAlias spark_catalog.default.v1
+ +- View (`spark_catalog`.`default`.`v1`, [c1#x, c2#x])
+ +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS
c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+CREATE FUNCTION foo2_4a(a ARRAY<STRING>) RETURNS STRING RETURN
+SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] FROM (SELECT MAP('a', 1,
'b', 2) rank)
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_4a`"
+ }
+}
+
+
+-- !query
+SELECT foo2_4a(ARRAY('a', 'b'))
+-- !query analysis
+Project [spark_catalog.default.foo2_4a(a#x) AS
spark_catalog.default.foo2_4a(array(a, b))#x]
+: +- Project [array_sort(outer(a#x), lambdafunction((rank#x[lambda i#x] -
rank#x[lambda j#x]), lambda i#x, lambda j#x, false), false)[0] AS
array_sort(outer(foo2_4a.a), lambdafunction((rank[namedlambdavariable()] -
rank[namedlambdavariable()]), namedlambdavariable(),
namedlambdavariable()))[0]#x]
+: +- SubqueryAlias __auto_generated_subquery_name
+: +- Project [map(a, 1, b, 2) AS rank#x]
+: +- OneRowRelation
++- Project [cast(array(a, b) as array<string>) AS a#x]
+ +- OneRowRelation
+
+
+-- !query
+CREATE FUNCTION foo2_4b(m MAP<STRING, STRING>, k STRING) RETURNS STRING RETURN
+SELECT v || ' ' || v FROM (SELECT upper(m[k]) AS v)
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
+{
+ "errorClass" : "ROUTINE_ALREADY_EXISTS",
+ "sqlState" : "42723",
+ "messageParameters" : {
+ "existingRoutineType" : "routine",
+ "newRoutineType" : "routine",
+ "routineName" : "`default`.`foo2_4b`"
+ }
+}
+
+
+-- !query
+SELECT foo2_4b(map('a', 'hello', 'b', 'world'), 'a')
+-- !query analysis
+Project [spark_catalog.default.foo2_4b(m#x, k#x) AS
spark_catalog.default.foo2_4b(map(a, hello, b, world), a)#x]
+: +- Project [concat(concat(v#x, ), v#x) AS concat(concat(v, ), v)#x]
+: +- SubqueryAlias __auto_generated_subquery_name
+: +- Project [upper(outer(m#x)[outer(k#x)]) AS v#x]
+: +- OneRowRelation
++- Project [cast(map(a, hello, b, world) as map<string,string>) AS m#x, cast(a
as string) AS k#x]
+ +- OneRowRelation
+
+
+-- !query
+DROP VIEW V2
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`V2`, false, true, false
+
+
+-- !query
+DROP VIEW V1
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`V1`, false, true, false
diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-udf.sql
b/sql/core/src/test/resources/sql-tests/inputs/sql-udf.sql
new file mode 100644
index 000000000000..34cb41d72676
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/sql-udf.sql
@@ -0,0 +1,122 @@
+-- test cases for SQL User Defined Functions
+
+-- 1. CREATE FUNCTION
+-- 1.1 Parameter
+-- 1.1.a A scalar function with various numbers of parameter
+-- Expect success
+CREATE FUNCTION foo1a0() RETURNS INT RETURN 1;
+-- Expect: 1
+SELECT foo1a0();
+-- Expect failure
+SELECT foo1a0(1);
+
+CREATE FUNCTION foo1a1(a INT) RETURNS INT RETURN 1;
+-- Expect: 1
+SELECT foo1a1(1);
+-- Expect failure
+SELECT foo1a1(1, 2);
+
+CREATE FUNCTION foo1a2(a INT, b INT, c INT, d INT) RETURNS INT RETURN 1;
+-- Expect: 1
+SELECT foo1a2(1, 2, 3, 4);
+
+-------------------------------
+-- 2. Scalar SQL UDF
+-- 2.1 deterministic simple expressions
+CREATE FUNCTION foo2_1a(a INT) RETURNS INT RETURN a;
+SELECT foo2_1a(5);
+
+CREATE FUNCTION foo2_1b(a INT, b INT) RETURNS INT RETURN a + b;
+SELECT foo2_1b(5, 6);
+
+CREATE FUNCTION foo2_1c(a INT, b INT) RETURNS INT RETURN 10 * (a + b) + 100 *
(a -b);
+SELECT foo2_1c(5, 6);
+
+CREATE FUNCTION foo2_1d(a INT, b INT) RETURNS INT RETURN ABS(a) -
LENGTH(CAST(b AS VARCHAR(10)));
+SELECT foo2_1d(-5, 6);
+
+-- 2.2 deterministic complex expression with subqueries
+-- 2.2.1 Nested Scalar subqueries
+CREATE FUNCTION foo2_2a(a INT) RETURNS INT RETURN SELECT a;
+SELECT foo2_2a(5);
+
+CREATE FUNCTION foo2_2b(a INT) RETURNS INT RETURN 1 + (SELECT a);
+SELECT foo2_2b(5);
+
+-- Expect error: deep correlation is not yet supported
+CREATE FUNCTION foo2_2c(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT a));
+-- SELECT foo2_2c(5);
+
+-- Expect error: deep correlation is not yet supported
+CREATE FUNCTION foo2_2d(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT (SELECT
(SELECT a))));
+-- SELECT foo2_2d(5);
+
+-- 2.2.2 Set operations
+-- Expect error: correlated scalar subquery must be aggregated.
+CREATE FUNCTION foo2_2e(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1) WHERE c1 = 2
+UNION ALL
+SELECT a + 1 FROM (VALUES 1) AS V(c1);
+-- SELECT foo2_2e(5);
+
+-- Expect error: correlated scalar subquery must be aggregated.
+CREATE FUNCTION foo2_2f(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1)
+EXCEPT
+SELECT a + 1 FROM (VALUES 1) AS V(a);
+-- SELECT foo2_2f(5);
+
+-- Expect error: correlated scalar subquery must be aggregated.
+CREATE FUNCTION foo2_2g(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1)
+INTERSECT
+SELECT a FROM (VALUES 1) AS V(a);
+-- SELECT foo2_2g(5);
+
+-- Prepare by dropping views or tables if they already exist.
+DROP TABLE IF EXISTS t1;
+DROP TABLE IF EXISTS t2;
+DROP TABLE IF EXISTS ts;
+DROP TABLE IF EXISTS tm;
+DROP TABLE IF EXISTS ta;
+DROP TABLE IF EXISTS V1;
+DROP TABLE IF EXISTS V2;
+DROP VIEW IF EXISTS t1;
+DROP VIEW IF EXISTS t2;
+DROP VIEW IF EXISTS ts;
+DROP VIEW IF EXISTS tm;
+DROP VIEW IF EXISTS ta;
+DROP VIEW IF EXISTS V1;
+DROP VIEW IF EXISTS V2;
+
+-- 2.3 Calling Scalar UDF from various places
+CREATE FUNCTION foo2_3(a INT, b INT) RETURNS INT RETURN a + b;
+CREATE VIEW V1(c1, c2) AS VALUES (1, 2), (3, 4), (5, 6);
+CREATE VIEW V2(c1, c2) AS VALUES (-1, -2), (-3, -4), (-5, -6);
+
+-- 2.3.1 Multiple times in the select list
+SELECT foo2_3(c1, c2), foo2_3(c2, 1), foo2_3(c1, c2) - foo2_3(c2, c1 - 1) FROM
V1 ORDER BY 1, 2, 3;
+
+-- 2.3.2 In the WHERE clause
+SELECT * FROM V1 WHERE foo2_3(c1, 0) = c1 AND foo2_3(c1, c2) < 8;
+
+-- 2.3.3 Different places around an aggregate
+SELECT foo2_3(SUM(c1), SUM(c2)), SUM(c1) + SUM(c2), SUM(foo2_3(c1, c2) +
foo2_3(c2, c1) - foo2_3(c2, c1))
+FROM V1;
+
+-- 2.4 Scalar UDF with complex one row relation subquery
+-- 2.4.1 higher order functions
+CREATE FUNCTION foo2_4a(a ARRAY<STRING>) RETURNS STRING RETURN
+SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] FROM (SELECT MAP('a', 1,
'b', 2) rank);
+
+SELECT foo2_4a(ARRAY('a', 'b'));
+
+-- 2.4.2 built-in functions
+CREATE FUNCTION foo2_4b(m MAP<STRING, STRING>, k STRING) RETURNS STRING RETURN
+SELECT v || ' ' || v FROM (SELECT upper(m[k]) AS v);
+
+SELECT foo2_4b(map('a', 'hello', 'b', 'world'), 'a');
+
+-- Clean up
+DROP VIEW V2;
+DROP VIEW V1;
diff --git a/sql/core/src/test/resources/sql-tests/results/sql-udf.sql.out
b/sql/core/src/test/resources/sql-tests/results/sql-udf.sql.out
new file mode 100644
index 000000000000..9f7af7c64487
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/sql-udf.sql.out
@@ -0,0 +1,484 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+CREATE FUNCTION foo1a0() RETURNS INT RETURN 1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo1a0()
+-- !query schema
+struct<spark_catalog.default.foo1a0():int>
+-- !query output
+1
+
+
+-- !query
+SELECT foo1a0(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+ "sqlState" : "42605",
+ "messageParameters" : {
+ "actualNum" : "1",
+ "docroot" : "https://spark.apache.org/docs/latest",
+ "expectedNum" : "0",
+ "functionName" : "`spark_catalog`.`default`.`foo1a0`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 16,
+ "fragment" : "foo1a0(1)"
+ } ]
+}
+
+
+-- !query
+CREATE FUNCTION foo1a1(a INT) RETURNS INT RETURN 1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo1a1(1)
+-- !query schema
+struct<spark_catalog.default.foo1a1(1):int>
+-- !query output
+1
+
+
+-- !query
+SELECT foo1a1(1, 2)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+ "sqlState" : "42605",
+ "messageParameters" : {
+ "actualNum" : "2",
+ "docroot" : "https://spark.apache.org/docs/latest",
+ "expectedNum" : "1",
+ "functionName" : "`spark_catalog`.`default`.`foo1a1`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 19,
+ "fragment" : "foo1a1(1, 2)"
+ } ]
+}
+
+
+-- !query
+CREATE FUNCTION foo1a2(a INT, b INT, c INT, d INT) RETURNS INT RETURN 1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo1a2(1, 2, 3, 4)
+-- !query schema
+struct<spark_catalog.default.foo1a2(1, 2, 3, 4):int>
+-- !query output
+1
+
+
+-- !query
+CREATE FUNCTION foo2_1a(a INT) RETURNS INT RETURN a
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_1a(5)
+-- !query schema
+struct<spark_catalog.default.foo2_1a(5):int>
+-- !query output
+5
+
+
+-- !query
+CREATE FUNCTION foo2_1b(a INT, b INT) RETURNS INT RETURN a + b
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_1b(5, 6)
+-- !query schema
+struct<spark_catalog.default.foo2_1b(5, 6):int>
+-- !query output
+11
+
+
+-- !query
+CREATE FUNCTION foo2_1c(a INT, b INT) RETURNS INT RETURN 10 * (a + b) + 100 *
(a -b)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_1c(5, 6)
+-- !query schema
+struct<spark_catalog.default.foo2_1c(5, 6):int>
+-- !query output
+10
+
+
+-- !query
+CREATE FUNCTION foo2_1d(a INT, b INT) RETURNS INT RETURN ABS(a) -
LENGTH(CAST(b AS VARCHAR(10)))
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_1d(-5, 6)
+-- !query schema
+struct<spark_catalog.default.foo2_1d(-5, 6):int>
+-- !query output
+4
+
+
+-- !query
+CREATE FUNCTION foo2_2a(a INT) RETURNS INT RETURN SELECT a
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_2a(5)
+-- !query schema
+struct<spark_catalog.default.foo2_2a(5):int>
+-- !query output
+5
+
+
+-- !query
+CREATE FUNCTION foo2_2b(a INT) RETURNS INT RETURN 1 + (SELECT a)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_2b(5)
+-- !query schema
+struct<spark_catalog.default.foo2_2b(5):int>
+-- !query output
+6
+
+
+-- !query
+CREATE FUNCTION foo2_2c(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT a))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION",
+ "sqlState" : "42703",
+ "messageParameters" : {
+ "objectName" : "`a`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 21,
+ "stopIndex" : 21,
+ "fragment" : "a"
+ } ]
+}
+
+
+-- !query
+CREATE FUNCTION foo2_2d(a INT) RETURNS INT RETURN 1 + (SELECT (SELECT (SELECT
(SELECT a))))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION",
+ "sqlState" : "42703",
+ "messageParameters" : {
+ "objectName" : "`a`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 37,
+ "stopIndex" : 37,
+ "fragment" : "a"
+ } ]
+}
+
+
+-- !query
+CREATE FUNCTION foo2_2e(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1) WHERE c1 = 2
+UNION ALL
+SELECT a + 1 FROM (VALUES 1) AS V(c1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE FUNCTION foo2_2f(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1)
+EXCEPT
+SELECT a + 1 FROM (VALUES 1) AS V(a)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE FUNCTION foo2_2g(a INT) RETURNS INT RETURN
+SELECT a FROM (VALUES 1) AS V(c1)
+INTERSECT
+SELECT a FROM (VALUES 1) AS V(a)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS t1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS t2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS ts
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS tm
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS ta
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS V1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS V2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS t1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS t2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS ts
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS tm
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS ta
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS V1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS V2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE FUNCTION foo2_3(a INT, b INT) RETURNS INT RETURN a + b
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE VIEW V1(c1, c2) AS VALUES (1, 2), (3, 4), (5, 6)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE VIEW V2(c1, c2) AS VALUES (-1, -2), (-3, -4), (-5, -6)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_3(c1, c2), foo2_3(c2, 1), foo2_3(c1, c2) - foo2_3(c2, c1 - 1) FROM V1 ORDER BY 1, 2, 3
+-- !query schema
+struct<spark_catalog.default.foo2_3(c1, c2):int,spark_catalog.default.foo2_3(c2, 1):int,(spark_catalog.default.foo2_3(c1, c2) - spark_catalog.default.foo2_3(c2, (c1 - 1))):int>
+-- !query output
+3 3 1
+7 5 1
+11 7 1
+
+
+-- !query
+SELECT * FROM V1 WHERE foo2_3(c1, 0) = c1 AND foo2_3(c1, c2) < 8
+-- !query schema
+struct<c1:int,c2:int>
+-- !query output
+1 2
+3 4
+
+
+-- !query
+SELECT foo2_3(SUM(c1), SUM(c2)), SUM(c1) + SUM(c2), SUM(foo2_3(c1, c2) + foo2_3(c2, c1) - foo2_3(c2, c1))
+FROM V1
+-- !query schema
+struct<spark_catalog.default.foo2_3(sum(c1), sum(c2)):int,(sum(c1) + sum(c2)):bigint,sum(((spark_catalog.default.foo2_3(c1, c2) + spark_catalog.default.foo2_3(c2, c1)) - spark_catalog.default.foo2_3(c2, c1))):bigint>
+-- !query output
+21 21 21
+
+
+-- !query
+CREATE FUNCTION foo2_4a(a ARRAY<STRING>) RETURNS STRING RETURN
+SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] FROM (SELECT MAP('a', 1, 'b', 2) rank)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_4a(ARRAY('a', 'b'))
+-- !query schema
+struct<spark_catalog.default.foo2_4a(array(a, b)):string>
+-- !query output
+a
+
+
+-- !query
+CREATE FUNCTION foo2_4b(m MAP<STRING, STRING>, k STRING) RETURNS STRING RETURN
+SELECT v || ' ' || v FROM (SELECT upper(m[k]) AS v)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT foo2_4b(map('a', 'hello', 'b', 'world'), 'a')
+-- !query schema
+struct<spark_catalog.default.foo2_4b(map(a, hello, b, world), a):string>
+-- !query output
+HELLO HELLO
+
+
+-- !query
+DROP VIEW V2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW V1
+-- !query schema
+struct<>
+-- !query output
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLFunctionSuite.scala
new file mode 100644
index 000000000000..4da3b9ab1d06
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLFunctionSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Test suite for SQL user-defined functions (UDFs).
+ */
+class SQLFunctionSuite extends QueryTest with SharedSparkSession {
+ import testImplicits._
+
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
+ }
+
+ test("SQL scalar function") {
+ withUserDefinedFunction("area" -> false) {
+ sql(
+ """
+ |CREATE FUNCTION area(width DOUBLE, height DOUBLE)
+ |RETURNS DOUBLE
+ |RETURN width * height
+ |""".stripMargin)
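+ // The analyzer resolves the call by substituting the function body, so area(1, 2)
+ // evaluates to 1.0 * 2.0 = 2.0 (arguments are cast to DOUBLE).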
+ checkAnswer(sql("SELECT area(1, 2)"), Row(2))
+ checkAnswer(sql("SELECT area(a, b) FROM t"), Seq(Row(0), Row(2)))
+ }
+ }
+
+ test("SQL scalar function with subquery in the function body") {
+ withUserDefinedFunction("foo" -> false) {
+ withTable("tbl") {
+ sql("CREATE TABLE tbl AS SELECT * FROM VALUES (1, 2), (1, 3), (2, 3) t(a, b)")
+ sql(
+ """
+ |CREATE FUNCTION foo(x INT) RETURNS INT
+ |RETURN SELECT SUM(b) FROM tbl WHERE x = a;
+ |""".stripMargin)
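+ // foo(x) evaluates the correlated subquery SUM(b) over rows of tbl with a = x:
+ // foo(1) = 2 + 3 = 5, while foo(0) matches no rows and returns null.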
+ checkAnswer(sql("SELECT foo(1)"), Row(5))
+ checkAnswer(sql("SELECT foo(a) FROM t"), Seq(Row(null), Row(5)))
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index c00f00ceaa35..a7af22a0554e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -79,7 +79,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
assert(info.getSource === "built-in")
val validSources = Seq(
- "built-in", "hive", "python_udf", "scala_udf", "java_udf", "python_udtf", "internal")
+ "built-in", "hive", "python_udf", "scala_udf", "java_udf", "python_udtf", "internal",
+ "sql_udf")
validSources.foreach { source =>
val info = new ExpressionInfo(
"testClass", null, "testName", null, "", "", "", "", "", "", source)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]