This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 57503b67d939 [SPARK-55528][SQL] Add default collation support for SQL
UDFs
57503b67d939 is described below
commit 57503b67d93971f22839c11086ad79191d750eee
Author: ilicmarkodb <[email protected]>
AuthorDate: Wed Mar 4 20:50:59 2026 +0800
[SPARK-55528][SQL] Add default collation support for SQL UDFs
### What changes were proposed in this pull request?
This PR adds default collation support for SQL user-defined functions,
enabling UDFs to inherit schema-level collations and specify explicit default
collations via the `DEFAULT COLLATION` clause.
**How DEFAULT COLLATION is applied:**
- **STRING parameters**: Parameters declared as `STRING` without explicit
collation (e.g., `p1 STRING`) receive the default collation
- **STRING return type**: When `RETURNS STRING` is specified without
explicit collation, the default collation is applied
- **Free string literals in body**: String literals in the UDF body receive
the default collation
- **Default string-producing built-in functions in body**: String-producing
built-in functions (e.g., `current_database()`) in the UDF body use the default
collation for their string outputs
Note: Explicit collations always take precedence. For example, `p1 STRING
COLLATE UTF8_BINARY` preserves `UTF8_BINARY` regardless of the default
collation.
### Why are the changes needed?
Currently, SQL UDFs in Spark don't support collation specifications. This
PR enables:
- UDFs to specify `DEFAULT COLLATION` clause in `CREATE FUNCTION` statements
- UDFs to automatically inherit the schema's default collation when not
explicitly specified
- Proper handling of explicit collations (e.g., `STRING COLLATE
UTF8_BINARY`) without being overridden by the default collation
- Collation support for table function return columns
### Does this PR introduce any user-facing change?
Yes. Users can now:
- Use `DEFAULT COLLATION <collation_name>` in `CREATE FUNCTION` statements
- Have UDFs automatically inherit the schema's default collation
Example:
```sql
-- UDF with explicit default collation
CREATE FUNCTION my_func(p1 STRING)
RETURNS STRING
DEFAULT COLLATION UTF8_LCASE
RETURN SELECT upper(p1);
-- String literals and return type get UTF8_LCASE
-- p1 parameter gets UTF8_LCASE (no explicit collation specified)
```
```sql
-- Explicit collation overrides default
CREATE FUNCTION my_func2(p1 STRING COLLATE UTF8_BINARY)
RETURNS STRING COLLATE de
DEFAULT COLLATION UTF8_LCASE
RETURN SELECT p1 || 'suffix';
-- p1 keeps UTF8_BINARY (explicit collation specified)
-- return type is 'de' (explicit collation specified)
-- 'suffix' literal gets UTF8_LCASE (default applies)
```
### How was this patch tested?
New tests in `DefaultCollationTestSuite`.
### Was this patch authored or co-authored using generative AI tooling?
Yes, co-authored with Claude Sonnet 4.5
Closes #54324 from ilicmarkodb/udf-default-collation.
Authored-by: ilicmarkodb <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 1 +
.../org/apache/spark/sql/types/StringType.scala | 37 ++
.../sql/connector/catalog/FunctionCatalog.java | 5 +
.../spark/sql/catalyst/analysis/Analyzer.scala | 20 +-
.../catalyst/analysis/ApplyDefaultCollation.scala | 13 +-
.../catalyst/analysis/CollationTypeCoercion.scala | 3 +
.../sql/catalyst/analysis/ResolveCatalogs.scala | 2 +-
.../catalyst/analysis/SQLFunctionExpression.scala | 7 +
.../spark/sql/catalyst/catalog/SQLFunction.scala | 27 +-
.../sql/catalyst/catalog/UserDefinedFunction.scala | 29 +-
.../sql/catalyst/plans/logical/v2Commands.scala | 1 +
.../sql/catalyst/catalog/SessionCatalogSuite.scala | 2 +
.../catalyst/analysis/ResolveSessionCatalog.scala | 5 +-
.../spark/sql/execution/SparkSqlParser.scala | 10 +-
.../command/CreateSQLFunctionCommand.scala | 10 +-
.../command/CreateUserDefinedFunctionCommand.scala | 2 +
.../sql/collation/DefaultCollationTestSuite.scala | 521 ++++++++++++++++++++-
.../command/CreateSQLFunctionParserSuite.scala | 2 +
18 files changed, 664 insertions(+), 33 deletions(-)
diff --git
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index d61dd137ec5a..24a6fb7e6d98 100644
---
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -1571,6 +1571,7 @@ routineCharacteristics
| sqlDataAccess
| nullCall
| commentSpec
+ | collationSpec
| rightsClause)*
;
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index 9f52f647a57a..34467c258d6c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -190,3 +190,40 @@ case object NoConstraint extends StringConstraint
case class FixedLength(length: Int) extends StringConstraint
case class MaxLength(length: Int) extends StringConstraint
+
+/**
+ * Used in the context of UDFs when resolving parameters/return types.
+ *
+ * For example, if a UDF parameter is defined as `p1 STRING COLLATE
UTF8_BINARY`, calling
+ * [[typeName]] will return just `STRING`, omitting the collation information.
This causes the
+ * parameter to be parsed into the companion object [[StringType]]. If the UDF
has a default
+ * collation specified, it will be applied to the companion object
[[StringType]], potentially
+ * resulting in the construction of a [[StringType]] with an invalid collation.
+ */
+object ExplicitUTF8BinaryStringType
+ extends StringType(CollationFactory.UTF8_BINARY_COLLATION_ID,
NoConstraint) {
+ override def typeName: String = s"string collate $collationName"
+ override def toString: String = s"StringType($collationName)"
+
+ /**
+ * Transforms the given `dataType` by replacing each [[StringType]] that has
an explicit
+ * `UTF8_BINARY` collation with `ExplicitUTF8BinaryStringType`.
+ */
+ def transform(dataType: DataType): DataType = {
+ dataType.transformRecursively {
+ case st: StringType if st.isUTF8BinaryCollation && !st.eq(StringType) =>
+ ExplicitUTF8BinaryStringType
+ }
+ }
+
+ /**
+ * Transforms the given `dataType` by replacing each companion object
[[StringType]] with
+ * explicit `UTF8_BINARY` [[StringType]].
+ */
+ def transformDefaultStringType(dataType: DataType): DataType = {
+ dataType.transformRecursively {
+ case st: StringType if st.eq(StringType) =>
+ StringType(CollationFactory.UTF8_BINARY_COLLATION_ID)
+ }
+ }
+}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
index de4559011942..09878509da9d 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/FunctionCatalog.java
@@ -30,6 +30,11 @@ import
org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
@Evolving
public interface FunctionCatalog extends CatalogPlugin {
+ /**
+ * A reserved property to specify the collation of the function.
+ */
+ String PROP_COLLATION = "collation";
+
/**
* List the functions in a namespace from the catalog.
* <p>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dd86c6c52cb9..12fc0f0a09fa 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -216,6 +216,12 @@ object AnalysisContext {
try f finally { set(originContext) }
}
+ def withAnalysisContext[A](function: SQLFunction)(f: => A): A = {
+ val originContext = value.get()
+ val context = originContext.copy(collation = function.collation)
+ set(context)
+ try f finally { set(originContext) }
+ }
def withNewAnalysisContext[A](f: => A): A = {
val originContext = value.get()
@@ -2340,8 +2346,10 @@ class Analyzer(
e: SubqueryExpression,
outer: LogicalPlan)(
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression):
SubqueryExpression = {
- val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) {
- executeSameContext(e.plan)
+ val newSubqueryPlan = SQLFunctionContext.withNewContext {
+ AnalysisContext.withOuterPlan(outer) {
+ executeSameContext(e.plan)
+ }
}
// If the subquery plan is fully resolved, pull the outer references and
record
@@ -2486,7 +2494,9 @@ class Analyzer(
Analyzer.retainResolutionConfigsForAnalysis(newConf = newConf,
existingConf = conf)
}
SQLConf.withExistingConf(newConf) {
- executeSameContext(plan)
+ AnalysisContext.withAnalysisContext(f.function) {
+ executeSameContext(plan)
+ }
}
}
// Fail the analysis eagerly if a SQL function cannot be resolved using
its input.
@@ -2785,7 +2795,9 @@ class Analyzer(
val resolved = SQLConf.withExistingConf(newConf) {
val plan = v1SessionCatalog.makeSQLTableFunctionPlan(name, function,
inputs, output)
SQLFunctionContext.withSQLFunction {
- executeSameContext(plan)
+ AnalysisContext.withAnalysisContext(function) {
+ executeSameContext(plan)
+ }
}
}
// Remove unnecessary lateral joins that are used to resolve the SQL
function.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollation.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollation.scala
index 3141e71ecadb..67d5b70b30a3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollation.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollation.scala
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Cast,
DefaultStringProducingExpression, Expression, Literal, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns,
AlterColumnSpec, AlterViewAs, ColumnDefinition, CreateTable,
CreateTableAsSelect, CreateTempView, CreateView, LogicalPlan, QualifiedColType,
ReplaceColumns, ReplaceTable, ReplaceTableAsSelect, TableSpec,
V2CreateTablePlan}
+import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumns,
AlterColumnSpec, AlterViewAs, ColumnDefinition, CreateTable,
CreateTableAsSelect, CreateTempView, CreateUserDefinedFunction, CreateView,
LogicalPlan, QualifiedColType, ReplaceColumns, ReplaceTable,
ReplaceTableAsSelect, TableSpec, V2CreateTablePlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils.{areSameBaseType,
isDefaultStringCharOrVarcharType, replaceDefaultStringCharAndVarcharTypes}
@@ -220,6 +220,17 @@ object ApplyDefaultCollation extends Rule[LogicalPlan] {
newAlterViewAs.copyTagsFrom(alterViewAs)
newAlterViewAs
+ case createUserDefinedFunction@CreateUserDefinedFunction(
+ ResolvedIdentifier(catalog: SupportsNamespaces, identifier),
+ _, _, _, _, _, collation, _, _, _, _, _, _) if collation.isEmpty =>
+ val newCreateUserDefinedFunction =
+ CurrentOrigin.withOrigin(createUserDefinedFunction.origin) {
+ createUserDefinedFunction.copy(
+ collation = getCollationFromSchemaMetadata(catalog,
identifier.namespace()))
+ }
+ newCreateUserDefinedFunction.copyTagsFrom(createUserDefinedFunction)
+ newCreateUserDefinedFunction
+
case other =>
other
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index 8d5b8c590fa4..75619c9c5ce3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -355,6 +355,9 @@ object CollationTypeCoercion extends SQLConfHelper {
case expr @ (_: NamedExpression | _: SubqueryExpression | _:
VariableReference) =>
Some(addContextToStringType(expr.dataType, Implicit))
+ case f: SQLFunctionExpression =>
+ Some(addContextToStringType(f.dataType, Implicit))
+
case lit: Literal =>
Some(addContextToStringType(lit.dataType, Default))
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
index d0a05e0495dc..5fa8ffefc012 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
@@ -91,7 +91,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
"CREATE", nameParts.last)
case CreateUserDefinedFunction(UnresolvedIdentifier(nameParts, _),
- _, _, _, _, _, _, _, _, _, _, _)
+ _, _, _, _, _, _, _, _, _, _, _, _)
if isSystemBuiltinName(nameParts) =>
throw QueryCompilationErrors.operationNotAllowedOnBuiltinFunctionError(
"CREATE", nameParts.last)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
index 37981f47287d..e7bdc8ec0248 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SQLFunctionExpression.scala
@@ -87,4 +87,11 @@ object SQLFunctionContext {
set(context)
try f finally { set(originContext) }
}
+
+ def withNewContext[A](f: => A): A = {
+ val originContext = value.get()
+ val context = SQLFunctionContext()
+ set(context)
+ try f finally { set(originContext) }
+ }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
index 84d87fab8b06..5724ce29742d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
@@ -29,7 +29,7 @@ import
org.apache.spark.sql.catalyst.catalog.UserDefinedFunction._
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo,
ScalarSubquery}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan,
OneRowRelation, Project}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, ExplicitUTF8BinaryStringType,
StructType}
/**
* Represent a SQL function.
@@ -40,6 +40,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
* @param exprText function body as an expression
* @param queryText function body as a query
* @param comment function comment
+ * @param collation function default collation
* @param deterministic whether the function is deterministic
* @param containsSQL whether the function has data access routine to be
CONTAINS SQL
* @param isTableFunc whether the function is a table function
@@ -54,6 +55,7 @@ case class SQLFunction(
exprText: Option[String],
queryText: Option[String],
comment: Option[String],
+ collation: Option[String],
deterministic: Option[Boolean],
containsSQL: Option[Boolean],
isTableFunc: Boolean,
@@ -152,16 +154,19 @@ case class SQLFunction(
*/
private def sqlFunctionToProps: Map[String, String] = {
val props = new mutable.HashMap[String, String]
- val inputParamText = inputParam.map(_.fields.map(_.toDDL).mkString(", "))
+ val inputParamText =
inputParam.map(ExplicitUTF8BinaryStringType.transform(_)
+ .asInstanceOf[StructType].fields.map(_.toDDL).mkString(", "))
inputParamText.foreach(props.put(INPUT_PARAM, _))
val returnTypeText = returnType match {
- case Left(dataType) => dataType.sql
- case Right(columns) => columns.toDDL
+ case Left(dataType) =>
ExplicitUTF8BinaryStringType.transform(dataType).sql
+ case Right(columns) =>
+
ExplicitUTF8BinaryStringType.transform(columns).asInstanceOf[StructType].toDDL
}
props.put(RETURN_TYPE, returnTypeText)
exprText.foreach(props.put(EXPRESSION, _))
queryText.foreach(props.put(QUERY, _))
comment.foreach(props.put(COMMENT, _))
+ collation.foreach(props.put(COLLATION, _))
deterministic.foreach(d => props.put(DETERMINISTIC, d.toString))
containsSQL.foreach(x => props.put(CONTAINS_SQL, x.toString))
props.put(IS_TABLE_FUNC, isTableFunc.toString)
@@ -185,6 +190,7 @@ object SQLFunction {
private val EXPRESSION: String = SQL_FUNCTION_PREFIX + "expression"
private val QUERY: String = SQL_FUNCTION_PREFIX + "query"
private val COMMENT: String = SQL_FUNCTION_PREFIX + "comment"
+ private val COLLATION: String = SQL_FUNCTION_PREFIX + "collation"
private val DETERMINISTIC: String = SQL_FUNCTION_PREFIX + "deterministic"
private val CONTAINS_SQL: String = SQL_FUNCTION_PREFIX + "containsSQL"
private val IS_TABLE_FUNC: String = SQL_FUNCTION_PREFIX + "isTableFunc"
@@ -211,14 +217,16 @@ object SQLFunction {
val blob = parts.sortBy(_._1).map(_._2).mkString
val props = mapper.readValue(blob, classOf[Map[String, String]])
val isTableFunc = props(IS_TABLE_FUNC).toBoolean
- val returnType = parseReturnTypeText(props(RETURN_TYPE), isTableFunc,
parser)
+ val collation = props.get(COLLATION)
+ val returnType = parseReturnTypeText(props(RETURN_TYPE), isTableFunc,
parser, collation)
SQLFunction(
name = function.identifier,
- inputParam = props.get(INPUT_PARAM).map(parseRoutineParam(_, parser)),
+ inputParam = props.get(INPUT_PARAM).map(parseRoutineParam(_, parser,
collation)),
returnType = returnType.get,
exprText = props.get(EXPRESSION),
queryText = props.get(QUERY),
comment = props.get(COMMENT),
+ collation = collation,
deterministic = props.get(DETERMINISTIC).map(_.toBoolean),
containsSQL = props.get(CONTAINS_SQL).map(_.toBoolean),
isTableFunc = isTableFunc,
@@ -249,7 +257,8 @@ object SQLFunction {
def parseReturnTypeText(
text: String,
isTableFunc: Boolean,
- parser: ParserInterface): Option[Either[DataType, StructType]] = {
+ parser: ParserInterface,
+ collation: Option[String]): Option[Either[DataType, StructType]] = {
if (!isTableFunc) {
// This is a scalar user-defined function.
if (text.isEmpty) {
@@ -257,7 +266,7 @@ object SQLFunction {
Option.empty[Either[DataType, StructType]]
} else {
// The CREATE FUNCTION statement included a RETURNS clause with an
explicit return type.
- Some(Left(parseDataType(text, parser)))
+ Some(Left(parseDataType(text, parser, collation)))
}
} else {
// This is a table function.
@@ -266,7 +275,7 @@ object SQLFunction {
Option.empty[Either[DataType, StructType]]
} else {
// The CREATE FUNCTION statement included a RETURNS TABLE clause with
an explicit schema.
- Some(Right(parseTableSchema(text, parser)))
+ Some(Right(parseTableSchema(text, parser, collation)))
}
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
index 3365b11b0742..4887830c4279 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
@@ -25,6 +25,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.ParserInterface
+import
org.apache.spark.sql.catalyst.types.DataTypeUtils.replaceDefaultStringCharAndVarcharTypes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types.{DataType, StructType}
@@ -86,21 +87,37 @@ object UserDefinedFunction {
// The default Hive Metastore SQL schema length for function resource uri.
private val HIVE_FUNCTION_RESOURCE_URI_LENGTH_THRESHOLD: Int = 4000
- def parseRoutineParam(text: String, parser: ParserInterface): StructType = {
- val parsed = parser.parseRoutineParam(text)
+ def parseRoutineParam(text: String, parser: ParserInterface, collation:
Option[String])
+ : StructType = {
+ val parsed = StructType(parser.parseRoutineParam(text)
+ .map(field => field.copy(dataType = resolveReturnType(field.dataType,
collation))))
CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
}
- def parseTableSchema(text: String, parser: ParserInterface): StructType = {
- val parsed = parser.parseTableSchema(text)
+ def parseTableSchema(text: String, parser: ParserInterface, collation:
Option[String])
+ : StructType = {
+ val parsed = StructType(parser.parseTableSchema(text)
+ .map(field => field.copy(dataType = resolveReturnType(field.dataType,
collation))))
CharVarcharUtils.failIfHasCharVarchar(parsed).asInstanceOf[StructType]
}
- def parseDataType(text: String, parser: ParserInterface): DataType = {
- val dataType = parser.parseDataType(text)
+ def parseDataType(text: String, parser: ParserInterface, collation:
Option[String]): DataType = {
+ val dataType = resolveReturnType(parser.parseDataType(text), collation)
CharVarcharUtils.failIfHasCharVarchar(dataType)
}
+ /**
+ * Resolve the return type by applying the default collation to non-collated
string, char and
+ * varchar types.
+ *
+ * @param returnType The return type is taken from the RETURNS clause,
+ * or inferred from the function's return value if the
clause is not specified.
+ * @param collation The default collation, if specified; otherwise, None.
+ */
+ def resolveReturnType(returnType: DataType, collation: Option[String]):
DataType = {
+ collation.map(replaceDefaultStringCharAndVarcharTypes(returnType,
_)).getOrElse(returnType)
+ }
+
private val _mapper: ObjectMapper = getObjectMapper
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index e22b55f625be..06a4d85a856c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -1377,6 +1377,7 @@ case class CreateUserDefinedFunction(
exprText: Option[String],
queryText: Option[String],
comment: Option[String],
+ collation: Option[String],
isDeterministic: Option[Boolean],
containsSQL: Option[Boolean],
language: RoutineLanguage,
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 47e0321bdfef..be7b4530e99e 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -2158,6 +2158,7 @@ abstract class SessionCatalogSuite extends AnalysisTest
with Eventually {
exprText = None,
queryText = None,
comment = None,
+ collation = None,
deterministic = Some(true),
containsSQL = Some(false),
isTableFunc = false,
@@ -2181,6 +2182,7 @@ abstract class SessionCatalogSuite extends AnalysisTest
with Eventually {
exprText = Some("SELECT 1"),
queryText = None,
comment = None,
+ collation = None,
deterministic = Some(true),
containsSQL = Some(true),
isTableFunc = true, // But marked as table function
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 92d208813cb3..7efd2e111317 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -588,7 +588,7 @@ class ResolveSessionCatalog(val catalogManager:
CatalogManager)
throw
QueryCompilationErrors.missingCatalogCreateFunctionAbilityError(catalog)
case c @ CreateUserDefinedFunction(
- CreateFunctionInSessionCatalog(ident), _, _, _, _, _, _, _, _, _, _,
_) =>
+ CreateFunctionInSessionCatalog(ident), _, _, _, _, _, _, _, _, _, _,
_, _) =>
CreateUserDefinedFunctionCommand(
FunctionIdentifier(ident.table, ident.database, ident.catalog),
c.inputParamText,
@@ -596,6 +596,7 @@ class ResolveSessionCatalog(val catalogManager:
CatalogManager)
c.exprText,
c.queryText,
c.comment,
+ c.collation,
c.isDeterministic,
c.containsSQL,
c.language,
@@ -605,7 +606,7 @@ class ResolveSessionCatalog(val catalogManager:
CatalogManager)
c.replace)
case CreateUserDefinedFunction(
- ResolvedIdentifier(catalog, _), _, _, _, _, _, _, _, _, _, _, _) =>
+ ResolvedIdentifier(catalog, _), _, _, _, _, _, _, _, _, _, _, _, _) =>
throw
QueryCompilationErrors.missingCatalogCreateFunctionAbilityError(catalog)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 834dc5035196..4c6df5dbe6cf 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -919,7 +919,7 @@ class SparkSqlAstBuilder extends AstBuilder {
val exprText = Option(ctx.expression()).map(source)
val queryText = Option(ctx.query()).map(source)
- val (containsSQL, deterministic, comment, optionalLanguage) =
+ val (containsSQL, deterministic, comment, collation, optionalLanguage) =
visitRoutineCharacteristics(ctx.routineCharacteristics())
val language: RoutineLanguage = optionalLanguage.getOrElse(LanguageSQL)
val isTableFunc = ctx.TABLE() != null ||
returnTypeText.equalsIgnoreCase("table")
@@ -933,6 +933,7 @@ class SparkSqlAstBuilder extends AstBuilder {
exprText,
queryText,
comment,
+ collation,
deterministic,
containsSQL,
language,
@@ -954,6 +955,7 @@ class SparkSqlAstBuilder extends AstBuilder {
exprText,
queryText,
comment,
+ collation,
deterministic,
containsSQL,
language,
@@ -979,7 +981,7 @@ class SparkSqlAstBuilder extends AstBuilder {
* rights: [SQL SECURITY INVOKER | SQL SECURITY DEFINER]
*/
override def visitRoutineCharacteristics(ctx: RoutineCharacteristicsContext)
- : (Option[Boolean], Option[Boolean], Option[String],
Option[RoutineLanguage]) =
+ : (Option[Boolean], Option[Boolean], Option[String], Option[String],
Option[RoutineLanguage]) =
withOrigin(ctx) {
checkDuplicateClauses(ctx.routineLanguage(), "LANGUAGE", ctx)
checkDuplicateClauses(ctx.specificName(), "SPECIFIC", ctx)
@@ -987,6 +989,7 @@ class SparkSqlAstBuilder extends AstBuilder {
checkDuplicateClauses(ctx.nullCall(), "NULL CALL", ctx)
checkDuplicateClauses(ctx.deterministic(), "DETERMINISTIC", ctx)
checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx)
+ checkDuplicateClauses(ctx.collationSpec(), "DEFAULT COLLATION", ctx)
checkDuplicateClauses(ctx.rightsClause(), "SQL SECURITY RIGHTS", ctx)
val language: Option[RoutineLanguage] = ctx
@@ -1004,13 +1007,14 @@ class SparkSqlAstBuilder extends AstBuilder {
val deterministic =
ctx.deterministic().asScala.headOption.map(visitDeterminism)
val comment = visitCommentSpecList(ctx.commentSpec())
+ val collation =
ctx.collationSpec().asScala.headOption.map(visitCollationSpec)
ctx.specificName().asScala.headOption.foreach(checkSpecificName)
ctx.nullCall().asScala.headOption.foreach(checkNullCall)
ctx.rightsClause().asScala.headOption.foreach(checkRightsClause)
val containsSQL: Option[Boolean] =
ctx.sqlDataAccess().asScala.headOption.map(visitDataAccess)
- (containsSQL, deterministic, comment, language)
+ (containsSQL, deterministic, comment, collation, language)
}
/**
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
index eb860089b0c8..730c3030428b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionCommand.scala
@@ -56,6 +56,7 @@ case class CreateSQLFunctionCommand(
exprText: Option[String],
queryText: Option[String],
comment: Option[String],
+ collation: Option[String],
isDeterministic: Option[Boolean],
containsSQL: Option[Boolean],
isTableFunc: Boolean,
@@ -72,8 +73,8 @@ case class CreateSQLFunctionCommand(
val catalog = sparkSession.sessionState.catalog
val conf = sparkSession.sessionState.conf
- val inputParam =
inputParamText.map(UserDefinedFunction.parseRoutineParam(_, parser))
- val returnType = parseReturnTypeText(returnTypeText, isTableFunc, parser)
+ val inputParam =
inputParamText.map(UserDefinedFunction.parseRoutineParam(_, parser, collation))
+ val returnType = parseReturnTypeText(returnTypeText, isTableFunc, parser,
collation)
val function = SQLFunction(
name,
@@ -82,6 +83,7 @@ case class CreateSQLFunctionCommand(
exprText,
queryText,
comment,
+ collation,
isDeterministic,
containsSQL,
isTableFunc,
@@ -159,7 +161,7 @@ case class CreateSQLFunctionCommand(
val analyzed = analyzer.execute(plan)
val (resolved, resolvedReturnType) = analyzed match {
case p @ Project(expr :: Nil, _) if expr.resolved =>
- (p, Left(expr.dataType))
+ (p, Left(resolveReturnType(expr.dataType, collation)))
case other =>
(other, function.returnType)
}
@@ -211,7 +213,7 @@ case class CreateSQLFunctionCommand(
throw
UserDefinedFunctionErrors.missingColumnNamesForSqlTableUdf(name.funcName)
case _ =>
StructType(analyzed.asInstanceOf[LateralJoin].right.plan.output.map { col =>
- StructField(col.name, col.dataType)
+ StructField(col.name, resolveReturnType(col.dataType,
collation))
})
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
index a3780a8bff19..f65c7c91251a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
@@ -46,6 +46,7 @@ object CreateUserDefinedFunctionCommand {
exprText: Option[String],
queryText: Option[String],
comment: Option[String],
+ collation: Option[String],
isDeterministic: Option[Boolean],
containsSQL: Option[Boolean],
language: RoutineLanguage,
@@ -67,6 +68,7 @@ object CreateUserDefinedFunctionCommand {
exprText,
queryText,
comment,
+ collation,
isDeterministic,
containsSQL,
isTableFunc,
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
index 88be0e79e4e6..bfa4dd982087 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/collation/DefaultCollationTestSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.DatasourceV2SQLBase
import
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{BooleanType, StringType, StructType}
abstract class DefaultCollationTestSuite extends QueryTest with
SharedSparkSession {
@@ -759,16 +759,40 @@ abstract class DefaultCollationTestSuite extends
QueryTest with SharedSparkSessi
}
}
+
abstract class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite {
+ // This is used for tests that don't depend on explicitly specifying the
data type
+ // (these tests still test the string type), or ones that are not applicable
to char/varchar
+ // types. E.g., UDFs don't support char/varchar as input parameters/return
types.
protected def stringTestNamesV1: Seq[String] = Seq(
"Check AttributeReference dataType from View with default collation",
"CTAS with DEFAULT COLLATION and VIEW",
"default string producing expressions in view definition",
+ "Test UDTF with default collation",
+ "Test UDF with default collation",
+ "Test UDTF with default collation and without columns in RETURNS TABLE",
+ "Test UDF with default collation and collation applied to return type",
+ "Test explicit UTF8_BINARY collation for UDF params/return type",
+ "ALTER SCHEMA DEFAULT COLLATION doesn't affect UDF/UDTF collation",
+ "Test applying collation to UDF params",
+ "Test UDF collation behavior with default and mixed collation settings",
+ "Test replacing UDF with default collation",
+ "Nested UDFs with default collation",
"View with UTF8_LCASE default collation from schema level"
- )
+ ) ++ schemaAndObjectCollationPairs.flatMap {
+ case (schemaDefaultCollation, udfDefaultCollation) => Seq(
+ s"""CREATE UDF/UDTF with schema level collation
+ | (schema default collation = $schemaDefaultCollation,
+ | view default collation = $udfDefaultCollation)""".stripMargin,
+ s"""CREATE OR UDF/UDTF with schema level collation
+ | (schema default collation = $schemaDefaultCollation,
+ | view default collation = $udfDefaultCollation)""".stripMargin
+ )
+ }
-  testString("Check AttributeReference dataType from View with default collation") {
+
+  testString("Check AttributeReference dataType from View with default collation") {
+    _ =>
withView(testView) {
sql(s"CREATE VIEW $testView DEFAULT COLLATION UTF8_LCASE AS SELECT 'a'
AS c1")
@@ -1070,6 +1094,404 @@ abstract class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite {
}
}
}
+ def emptyCreateTable()(f: => Unit): Unit = {
+ f
+ }
+
+ def createTable(dataType: String)(f: => Unit): Unit = {
+ withTable(testTable1) {
+ sql(
+ s"""CREATE TABLE $testTable1
+ | (c1 $dataType COLLATE UNICODE, c2 $dataType COLLATE SR_AI, c3 INT)
+ |""".stripMargin)
+ // scalastyle:off
+ sql(s"INSERT INTO $testTable1 VALUES ('a', 'a', 1)")
+ // scalastyle:on
+ f
+ }
+ }
+
+ def testUDF()(
+ createAndCheckUDF: (String, String, Boolean, String, String) => Unit):
Unit = {
+ val functionName = "f"
+ val prefix = s"${CollationFactory.CATALOG}.${CollationFactory.SCHEMA}"
+ Seq(
+ ("", "", false),
+ ("", "TEMPORARY", true),
+ ("OR REPLACE", "", false),
+ ("OR REPLACE", "TEMPORARY", true)
+ ).foreach {
+ case (replace, temporary, isTemporary) =>
+ createAndCheckUDF(replace, temporary, isTemporary, functionName,
prefix)
+ }
+ }
+
+ testString("Test UDTF with default collation") {
+ dataType =>
+ testUDF() {
+ (replace, temporary, isTemporary, functionName, prefix) =>
+ createTable(dataType) {
+ withUserDefinedFunction((functionName, isTemporary)) {
+ // Table function
+ sql(
+ s"""CREATE $replace $temporary FUNCTION $functionName()
+ | RETURNS TABLE
+ | (c1 $dataType COLLATE UTF8_LCASE, c2 $dataType, c3 INT, c4
$dataType)
+ | DEFAULT COLLATION UNICODE_CI
+ | RETURN
+ | SELECT *, 'w' AS c4
+ | FROM $testTable1
+ | WHERE 'a' = 'A'
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT COUNT(*) FROM $functionName()"), Row(1))
+ checkAnswer(sql(s"SELECT COLLATION(c1) FROM $functionName()"),
+ Row(s"$prefix.UTF8_LCASE"))
+ checkAnswer(sql(s"SELECT COLLATION(c2) FROM $functionName()"),
+ Row(s"$prefix.UNICODE_CI"))
+ checkAnswer(sql(s"SELECT COLLATION(c4) FROM $functionName()"),
+ Row(s"$prefix.UNICODE_CI"))
+ checkAnswer(sql(s"SELECT c1 = 'A' FROM $functionName()"),
Row(true))
+ checkAnswer(sql(s"SELECT c2 = 'A' FROM $functionName()"),
Row(true))
+ checkAnswer(sql(s"SELECT c4 = 'W' FROM $functionName()"),
Row(true))
+ }
+ }
+ }
+ }
+
+ testString("Test UDTF with default collation and without columns in RETURNS
TABLE") { _ =>
+ testUDF() {
+ (replace, temporary, isTemporary, functionName, prefix) =>
+ withUserDefinedFunction((functionName, isTemporary)) {
+ sql(
+ s"""CREATE $replace $temporary FUNCTION $functionName()
+ | RETURNS TABLE
+ | DEFAULT COLLATION UTF8_LCASE
+ | RETURN
+ | SELECT 'a' AS c1, 'b' COLLATE UTF8_BINARY AS c2, 'c' COLLATE
UNICODE AS c3
+ | WHERE 'a' = 'A'
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $functionName()"), Row("a", "b",
"c"))
+ checkAnswer(sql(s"SELECT COLLATION(c1) FROM $functionName()"),
+ Row(s"$prefix.UTF8_LCASE"))
+ checkAnswer(sql(s"SELECT COLLATION(c2) FROM $functionName()"),
+ Row(s"$prefix.UTF8_BINARY"))
+ checkAnswer(sql(s"SELECT COLLATION(c3) FROM $functionName()"),
+ Row(s"$prefix.UNICODE"))
+ }
+ }
+ }
+
+ testString("Test UDF with default collation") { dataType =>
+ testUDF() {
+ (replace, temporary, isTemporary, functionName, prefix) =>
+ createTable(dataType) {
+ withUserDefinedFunction((functionName, isTemporary)) {
+ sql(
+ s"""CREATE $replace $temporary FUNCTION $functionName()
+ | RETURNS $dataType COLLATE UTF8_LCASE
+ | DEFAULT COLLATION UNICODE_CI
+ | RETURN
+ | SELECT c1
+ | FROM $testTable1
+ | WHERE 'a' = 'A'
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT COUNT($functionName())"), Row(1))
+ checkAnswer(sql(s"SELECT COLLATION($functionName())"),
+ Row(s"$prefix.UTF8_LCASE"))
+ checkAnswer(sql(s"SELECT $functionName() = 'A'"), Row(true))
+ }
+ }
+ }
+ }
+
+ testString("Test UDF with default collation and collation applied to return
type") {
+ dataType =>
+ testUDF() {
+ (replace, temporary, isTemporary, functionName, prefix) =>
+ createTable(dataType) {
+ withUserDefinedFunction((functionName, isTemporary)) {
+ sql(
+ s"""CREATE $replace $temporary FUNCTION $functionName()
+ | RETURNS $dataType
+ | DEFAULT COLLATION UNICODE
+ | RETURN
+ | SELECT c1
+ | FROM $testTable1
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT COUNT($functionName())"), Row(1))
+ checkAnswer(sql(s"SELECT COLLATION($functionName())"),
+ Row(s"$prefix.UNICODE"))
+ checkAnswer(sql(s"SELECT $functionName() = 'A'"), Row(false))
+ }
+ }
+ }
+ }
+
+ testString("Test explicit UTF8_BINARY collation for UDF params/return type")
{
+ dataType =>
+ testUDF() {
+ (replace, temporary, isTemporary, functionName, prefix) =>
+ emptyCreateTable() {
+ withUserDefinedFunction((functionName, isTemporary)) {
+ sql(
+ s"""CREATE $replace $temporary FUNCTION $functionName
+ | (p1 $dataType COLLATE UTF8_BINARY, p2 $dataType)
+ | RETURNS $dataType COLLATE UTF8_BINARY
+ | DEFAULT COLLATION UTF8_LCASE
+ | RETURN
+ | SELECT CASE WHEN p1 != 'A' AND p2 = 'B' THEN 'C' ELSE 'D'
END
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT $functionName('a', 'b') = 'C'"),
Row(true))
+ checkAnswer(sql(s"SELECT $functionName('a', 'b') = 'c'"),
Row(false))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION($functionName('b',
'c'))"),
+ Row(s"$prefix.UTF8_BINARY"))
+ }
+ }
+ }
+
+ // Table UDF
+ testUDF() {
+ (replace, temporary, isTemporary, functionName, prefix) =>
+ emptyCreateTable() {
+ withUserDefinedFunction((functionName, isTemporary)) {
+ sql(
+ s"""CREATE $replace $temporary FUNCTION $functionName
+ | (p1 $dataType COLLATE UTF8_BINARY, p2 $dataType)
+ | RETURNS TABLE
+ | (c1 $dataType COLLATE UTF8_BINARY, c2 $dataType)
+ | DEFAULT COLLATION UTF8_LCASE
+ | RETURN
+ | SELECT CASE WHEN p1 != 'A' AND p2 = 'B' THEN 'C' ELSE 'D'
END, 'E'
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT c1 = 'C', c2 = 'E' FROM
$functionName('a', 'b')"),
+ Row(true, true))
+ checkAnswer(sql(s"SELECT c1 ='c' FROM $functionName('a', 'b')"),
Row(false))
+ checkAnswer(sql(s"SELECT c2 ='e' FROM $functionName('a', 'b')"),
Row(true))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM
$functionName('a', 'b')"),
+ Row(s"$prefix.UTF8_BINARY"))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM
$functionName('a', 'b')"),
+ Row(s"$prefix.UTF8_LCASE"))
+ }
+ }
+ }
+ }
+
+ // UDF with schema level collation tests
+ schemaAndObjectCollationPairs.foreach {
+ case (schemaDefaultCollation, udfDefaultCollation) =>
+ testString(
+ s"""CREATE UDF/UDTF with schema level collation
+ | (schema default collation = $schemaDefaultCollation,
+ | view default collation = $udfDefaultCollation)""".stripMargin) {
dataType =>
+ testCreateUDFWithSchemaLevelCollation(dataType,
schemaDefaultCollation, udfDefaultCollation)
+ }
+
+ testString(
+ s"""CREATE OR UDF/UDTF with schema level collation
+ | (schema default collation = $schemaDefaultCollation,
+ | view default collation = $udfDefaultCollation)""".stripMargin) {
dataType =>
+ testCreateUDFWithSchemaLevelCollation(dataType,
schemaDefaultCollation, udfDefaultCollation)
+ }
+ }
+
+ testString("ALTER SCHEMA DEFAULT COLLATION doesn't affect UDF/UDTF
collation") {
+ dataType =>
+ val functionName = "f"
+ val prefix = "SYSTEM.BUILTIN"
+
+ withDatabase(testSchema) {
+ sql(s"CREATE SCHEMA $testSchema DEFAULT COLLATION UTF8_LCASE")
+ sql(s"USE $testSchema")
+
+ withUserDefinedFunction((functionName, false)) {
+ sql(s"CREATE FUNCTION $functionName() RETURN SELECT 'a' WHERE 'b' =
'B'")
+
+ checkAnswer(sql(s"SELECT $functionName()"), Row("a"))
+ checkAnswer(sql(s"SELECT COLLATION($functionName())"),
Row(s"$prefix.UTF8_LCASE"))
+
+ // ALTER SCHEMA DEFAULT COLLATION
+ sql(s"ALTER SCHEMA $testSchema DEFAULT COLLATION UNICODE")
+
+ checkAnswer(sql(s"SELECT $functionName()"), Row("a"))
+ checkAnswer(sql(s"SELECT COLLATION($functionName())"),
Row(s"$prefix.UTF8_LCASE"))
+ }
+ }
+
+ withDatabase(testSchema) {
+ sql(s"CREATE SCHEMA $testSchema DEFAULT COLLATION UTF8_LCASE")
+ sql(s"USE $testSchema")
+
+ withUserDefinedFunction((functionName, false)) {
+ sql(
+ s"""CREATE FUNCTION $functionName()
+ |RETURNS TABLE (c1 $dataType, c2 $dataType COLLATE UTF8_BINARY,
+ |c3 $dataType COLLATE UNICODE)
+ |RETURN
+ |SELECT 'a', 'b', 'c' WHERE 'd' = 'D'
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $functionName()"), Row("a", "b", "c"))
+ checkAnswer(sql(s"SELECT COLLATION(c1) FROM $functionName()"),
Row(s"$prefix.UTF8_LCASE"))
+ checkAnswer(sql(s"SELECT COLLATION(c2) FROM $functionName()"),
Row(s"$prefix.UTF8_BINARY"))
+ checkAnswer(sql(s"SELECT COLLATION(c3) FROM $functionName()"),
Row(s"$prefix.UNICODE"))
+
+ // ALTER SCHEMA DEFAULT COLLATION
+ sql(s"ALTER SCHEMA $testSchema DEFAULT COLLATION UNICODE")
+
+ checkAnswer(sql(s"SELECT * FROM $functionName()"), Row("a", "b", "c"))
+ checkAnswer(sql(s"SELECT COLLATION(c1) FROM $functionName()"),
Row(s"$prefix.UTF8_LCASE"))
+ checkAnswer(sql(s"SELECT COLLATION(c2) FROM $functionName()"),
Row(s"$prefix.UTF8_BINARY"))
+ checkAnswer(sql(s"SELECT COLLATION(c3) FROM $functionName()"),
Row(s"$prefix.UNICODE"))
+ }
+ }
+ }
+
+ testString("Test applying collation to UDF params") { dataType =>
+ testUDF() {
+ (replace, temporary, isTemporary, functionName, prefix) =>
+ emptyCreateTable() {
+ withUserDefinedFunction((functionName, isTemporary)) {
+ sql(
+ s"""CREATE $replace $temporary FUNCTION $functionName
+ | (p1 $dataType, p2 $dataType COLLATE UNICODE)
+ | RETURNS TABLE
+ | (c1 BOOLEAN, c2 BOOLEAN, c3 $dataType, c4 $dataType COLLATE
UNICODE,
+ | c5 $dataType COLLATE SR_AI)
+ | DEFAULT COLLATION UTF8_LCASE
+ | RETURN
+ | SELECT p1 = 'A', p2 = 'A', p2, p2, p2
+ | WHERE p1 = 'A'
+ |""".stripMargin)
+
+ val expected = Seq(
+ Row(true, false, "a", "a", "a")
+ )
+ val expectedSchema = new StructType()
+ .add("c1", BooleanType)
+ .add("c2", BooleanType)
+ .add("c3", StringType)
+ .add("c4", StringType)
+ .add("c5", StringType)
+ checkAnswer(sql(s"SELECT * FROM $functionName('a', 'a')"),
+ spark.createDataFrame(spark.sparkContext.parallelize(expected),
expectedSchema))
+ checkAnswer(sql(s"SELECT COLLATION(c3) FROM $functionName('a',
'a')"),
+ Row(s"$prefix.UTF8_LCASE"))
+ checkAnswer(sql(s"SELECT COLLATION(c4) FROM $functionName('a',
'a')"),
+ Row(s"$prefix.UNICODE"))
+ checkAnswer(sql(s"SELECT COLLATION(c5) FROM $functionName('a',
'a')"),
+ Row(s"$prefix.sr_AI"))
+ checkAnswer(sql(s"SELECT c3 = 'A' FROM $functionName('a', 'a')"),
+ Row(true))
+ checkAnswer(sql(s"SELECT c4 = 'A' FROM $functionName('a', 'a')"),
+ Row(false))
+ checkAnswer(sql(s"SELECT c5 = 'A' FROM $functionName('a', 'a')"),
+ Row(false))
+ }
+ }
+ }
+ }
+
+ testString("Test UDF collation behavior with default and mixed collation
settings") {
+ dataType =>
+ testUDF() {
+ (replace, temporary, isTemporary, functionName, prefix) =>
+ emptyCreateTable() {
+ val fullFunctionName =
+ if (isTemporary) {
+ functionName
+ } else {
+ s"spark_catalog.default.$functionName"
+ }
+
+ Seq(
+ // (returnsClause, returnType, otherCollation, inputChar,
compareChar)
+ ("", "UTF8_LCASE", "SR_AI", "w", "W"),
+ (s"RETURNS $dataType", "UTF8_LCASE", "SR_AI", "w", "W"),
+ // scalastyle:off
+ (s"RETURNS $dataType COLLATE SR_AI", "sr_AI", "UTF8_LCASE", "ć",
"č")
+ // scalastyle:on
+ ).foreach {
+ case (returnsClause, returnTypeCollation, otherCollation,
inputChar, equalChar) =>
+ withUserDefinedFunction((functionName, isTemporary)) {
+ sql(
+ s"""CREATE $replace $temporary FUNCTION $functionName()
$returnsClause
+ | DEFAULT COLLATION UTF8_LCASE
+ | RETURN
+ | SELECT '$inputChar' AS c1
+ | WHERE 'a' = 'A'""".stripMargin)
+
+ checkAnswer(sql(s"SELECT COUNT($functionName())"), Row(1))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION($functionName())"),
+ Row(s"$prefix.$returnTypeCollation"))
+ checkAnswer(
+ sql(s"SELECT $functionName() =" +
+ s" (SELECT '$equalChar' COLLATE $returnTypeCollation)"),
+ Row(true))
+
+ val exception = intercept[AnalysisException] {
+ sql(s"SELECT $functionName() = (SELECT 'a' COLLATE
$otherCollation)")
+ }
+ assert(exception.getMessage.contains("indeterminate
collation"))
+ }
+ }
+ }
+ }
+ }
+
+ testString("Test replacing UDF with default collation") { _ =>
+ val functionName = "f"
+ val prefix = "SYSTEM.BUILTIN"
+
+ withUserDefinedFunction((functionName, false)) {
+ sql(
+ s"""CREATE FUNCTION $functionName()
+ | RETURN
+ | SELECT 'a'
+ |""".stripMargin)
+ sql(
+ s"""CREATE OR REPLACE FUNCTION $functionName()
+ | DEFAULT COLLATION UTF8_LCASE
+ | RETURN
+ | SELECT 'a' AS c1
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION($functionName())"),
+ Row(s"$prefix.UTF8_LCASE"))
+ checkAnswer(sql(s"SELECT $functionName() = 'A'"), Row(true))
+ }
+ }
+
+ testString("Nested UDFs with default collation") {
+ dataType =>
+ val function1Name = "f1"
+ val function2Name = "f2"
+ withUserDefinedFunction((function1Name, false)) {
+ sql(
+ s"""CREATE FUNCTION $function1Name(s $dataType)
+ | DEFAULT COLLATION UTF8_LCASE
+ | RETURN
+ | SELECT s
+ |""".stripMargin)
+ withUserDefinedFunction((function2Name, false)) {
+ // scalastyle:off
+ sql(
+ s"""CREATE FUNCTION $function2Name()
+ | DEFAULT COLLATION SR_AI
+ | RETURN
+ | SELECT 'č'
+ | WHERE $function1Name('a') = $function1Name('A')
+ |""".stripMargin)
+ // scalastyle:on
+ checkAnswer(sql(s"SELECT COUNT($function2Name())"), Row(1))
+ }
+ }
+ }
// View with schema level collation tests
schemaAndObjectCollationPairs.foreach {
@@ -1252,6 +1674,99 @@ abstract class DefaultCollationTestSuiteV1 extends DefaultCollationTestSuite {
}
}
}
+ private def testCreateUDFWithSchemaLevelCollation(
+ dataType: String,
+ schemaDefaultCollation: String,
+ udfDefaultCollation: Option[String],
+ replaceUDF: Boolean = false): Unit = {
+ val prefix = "SYSTEM.BUILTIN"
+ val functionName = "f"
+
+ val (udfDefaultCollationClause, resolvedDefaultCollation) =
+ if (udfDefaultCollation.isDefined) {
+ (s"DEFAULT COLLATION ${udfDefaultCollation.get}",
udfDefaultCollation.get)
+ } else {
+ ("", schemaDefaultCollation)
+ }
+ val replace = if (replaceUDF) "OR REPLACE" else ""
+
+ Seq(/* alterSchemaCollation */ false, true).foreach {
+ alterSchemaCollation =>
+ withDatabase(testSchema) {
+ if (!alterSchemaCollation) {
+ sql(s"CREATE SCHEMA $testSchema DEFAULT COLLATION
$schemaDefaultCollation")
+ } else {
+ sql(s"CREATE SCHEMA $testSchema DEFAULT COLLATION EN")
+ sql(s"ALTER SCHEMA $testSchema DEFAULT COLLATION
$schemaDefaultCollation")
+ }
+ sql(s"USE $testSchema")
+
+ Seq(
+ // (returnClause, outputCollation)
+ ("", resolvedDefaultCollation),
+ (s"RETURNS $dataType", resolvedDefaultCollation),
+ (s"RETURNS $dataType COLLATE FR", "fr")
+ ).foreach {
+ case (returnClause, outputCollation) =>
+ withUserDefinedFunction((functionName, false)) {
+ // scalastyle:off
+ sql(
+ s"""CREATE $replace FUNCTION $functionName
+ |(p1 $dataType, p2 $dataType COLLATE UTF8_BINARY, p3
$dataType COLLATE SR_AI_CI)
+ |$returnClause
+ |$udfDefaultCollationClause
+ |RETURN SELECT 'a' AS c1 WHERE p2 != 'A' AND p3 = 'Č'
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT $functionName('x', 'a', 'ć')"),
Row("a"))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION($functionName('x',
'a', 'ć'))"),
+ Row(s"$prefix.$outputCollation"))
+ // scalastyle:on
+ }
+ }
+
+ withUserDefinedFunction((functionName, false)) {
+ sql(
+ s"""CREATE $replace FUNCTION $functionName()
+ |RETURNS TABLE
+ |(c1 $dataType, c2 $dataType COLLATE UTF8_BINARY, c3
$dataType COLLATE SR_AI_CI)
+ |$udfDefaultCollationClause
+ |RETURN
+ |SELECT 'a', 'b', 'c'
+ |""".stripMargin)
+
+ checkAnswer(sql(s"SELECT * FROM $functionName()"), Row("a", "b",
"c"))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM
$functionName()"),
+ Row(s"$prefix.$resolvedDefaultCollation"))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM
$functionName()"),
+ Row(s"$prefix.UTF8_BINARY"))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION(c3) FROM
$functionName()"),
+ Row(s"$prefix.sr_CI_AI"))
+ }
+
+ withUserDefinedFunction((functionName, false)) {
+ val pairs = defaultStringProducingExpressions.zipWithIndex.map {
+ case (expr, index) => (s"$expr AS c${index + 1}", s"c${index +
1} $dataType")
+ }
+ val columns = pairs.map(_._1).mkString(", ")
+ val returnsClause = pairs.map(_._2).mkString(", ")
+
+ sql(
+ s"""CREATE $replace FUNCTION $functionName()
+ |RETURNS TABLE
+ |($returnsClause)
+ |$udfDefaultCollationClause
+ |RETURN SELECT $columns
+ |""".stripMargin)
+
+ (1 to defaultStringProducingExpressions.length).foreach { index =>
+ checkAnswer(sql(s"SELECT COLLATION(c$index) FROM
$functionName()"),
+ Row(s"$prefix.$resolvedDefaultCollation"))
+ }
+ }
+ }
+ }
+ }
}
abstract class DefaultCollationTestSuiteV2
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala
index 25d8a74797ce..56316f43f8df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateSQLFunctionParserSuite.scala
@@ -58,6 +58,7 @@ class CreateSQLFunctionParserSuite extends AnalysisTest {
exprText = exprText,
queryText = queryText,
comment = comment,
+ collation = None,
isDeterministic = isDeterministic,
containsSQL = containsSQL,
language = LanguageSQL,
@@ -87,6 +88,7 @@ class CreateSQLFunctionParserSuite extends AnalysisTest {
exprText = exprText,
queryText = queryText,
comment = comment,
+ collation = None,
isDeterministic = isDeterministic,
containsSQL = containsSQL,
isTableFunc = isTableFunc,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]