This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b05ef451061f [SPARK-50175][SQL] Change collation precedence calculation
b05ef451061f is described below
commit b05ef451061f9ff7ba15bf037199339ed7236748
Author: Stefan Kandic <[email protected]>
AuthorDate: Thu Nov 21 17:38:57 2024 +0800
[SPARK-50175][SQL] Change collation precedence calculation
### What changes were proposed in this pull request?
Changing the way how the collation strength of string expressions are
calculated. Currently, there are three different collation strengths:
- explicit - result of the `collate` expression
- implicit - column references and output of string functions
- default - literals and cast expression
However, unlike in other database systems (pg, sql server) collation
strengths were not propagated up the expression tree, meaning that
`substring('a' collate unicode, 0, 1)` would have implicit strength because it
is the result of a string expression.
My proposal is to change the behavior to be more in line with other systems
mentioned above; and to do it by traversing the expression tree and propagating
the highest precedence strengths up (explicit being the highest and default the
lowest) while also finding the conflicts between them (conflicting explicit or
implicit strengths).
### Why are the changes needed?
To be more consistent with other systems that have collations (PostgreSQL, SQL
Server etc.)
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48663 from stefankandic/newCollationPrec-separate.
Authored-by: Stefan Kandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/analysis/CollationTypeCoercion.scala | 347 +++++++++++++-------
.../catalyst/expressions/complexTypeCreator.scala | 7 +-
.../expressions/complexTypeExtractors.scala | 1 +
.../catalyst/expressions/stringExpressions.scala | 8 +-
.../spark/sql/CollationExpressionWalkerSuite.scala | 2 +-
.../spark/sql/CollationSQLExpressionsSuite.scala | 11 +-
.../sql/CollationStringExpressionsSuite.scala | 6 +-
.../org/apache/spark/sql/CollationSuite.scala | 64 ++--
.../collation/CollationTypePrecedenceSuite.scala | 361 ++++++++++++++++++++-
9 files changed, 630 insertions(+), 177 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index 1e9c3aabedb3..532e5e0d0a06 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -17,95 +17,51 @@
package org.apache.spark.sql.catalyst.analysis
-import javax.annotation.Nullable
-
import scala.annotation.tailrec
+import org.apache.spark.sql.catalyst.analysis.CollationStrength.{Default,
Explicit, Implicit}
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType,
haveSameType}
-import org.apache.spark.sql.catalyst.expressions.{
- ArrayAppend,
- ArrayContains,
- ArrayExcept,
- ArrayIntersect,
- ArrayJoin,
- ArrayPosition,
- ArrayRemove,
- ArraysOverlap,
- ArrayUnion,
- CaseWhen,
- Cast,
- Coalesce,
- Collate,
- Concat,
- ConcatWs,
- Contains,
- CreateArray,
- CreateMap,
- Elt,
- EndsWith,
- EqualNullSafe,
- EqualTo,
- Expression,
- FindInSet,
- GetMapValue,
- GreaterThan,
- GreaterThanOrEqual,
- Greatest,
- If,
- In,
- InSubquery,
- Lag,
- Lead,
- Least,
- LessThan,
- LessThanOrEqual,
- Levenshtein,
- Literal,
- Mask,
- Overlay,
- RaiseError,
- RegExpReplace,
- SplitPart,
- StartsWith,
- StringInstr,
- StringLocate,
- StringLPad,
- StringReplace,
- StringRPad,
- StringSplitSQL,
- StringToMap,
- StringTranslate,
- StringTrim,
- StringTrimLeft,
- StringTrimRight,
- SubstringIndex,
- ToNumber,
- TryToNumber
-}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}
+import org.apache.spark.sql.util.SchemaUtils
/**
* Type coercion helper that matches against expressions in order to apply
collation type coercion.
*/
object CollationTypeCoercion {
+
+ private val COLLATION_CONTEXT_TAG = new
TreeNodeTag[CollationContext]("collationContext")
+
+ private def hasCollationContextTag(expr: Expression): Boolean = {
+ expr.getTagValue(COLLATION_CONTEXT_TAG).isDefined
+ }
+
def apply(expression: Expression): Expression = expression match {
+ case cast: Cast if shouldRemoveCast(cast) =>
+ cast.child
+
case ifExpr: If =>
ifExpr.withNewChildren(
ifExpr.predicate +: collateToSingleType(Seq(ifExpr.trueValue,
ifExpr.falseValue))
)
case caseWhenExpr: CaseWhen if
!haveSameType(caseWhenExpr.inputTypesForMerging) =>
- val outputStringType =
- getOutputCollation(caseWhenExpr.branches.map(_._2) ++
caseWhenExpr.elseValue)
- val newBranches = caseWhenExpr.branches.map {
- case (condition, value) =>
- (condition, castStringType(value, outputStringType).getOrElse(value))
+ val outputStringType = findLeastCommonStringType(
+ caseWhenExpr.branches.map(_._2) ++ caseWhenExpr.elseValue)
+ outputStringType match {
+ case Some(st) =>
+ val newBranches = caseWhenExpr.branches.map { case (condition,
value) =>
+ (condition, castStringType(value, st))
+ }
+ val newElseValue =
+ caseWhenExpr.elseValue.map(e => castStringType(e, st))
+ CaseWhen(newBranches, newElseValue)
+
+ case _ =>
+ caseWhenExpr
}
- val newElseValue =
- caseWhenExpr.elseValue.map(e => castStringType(e,
outputStringType).getOrElse(e))
- CaseWhen(newBranches, newElseValue)
case stringLocate: StringLocate =>
stringLocate.withNewChildren(
@@ -156,6 +112,12 @@ object CollationTypeCoercion {
val newValues = collateToSingleType(mapCreate.values)
mapCreate.withNewChildren(newKeys.zip(newValues).flatMap(pair =>
Seq(pair._1, pair._2)))
+ case namedStruct: CreateNamedStruct if namedStruct.children.size % 2 == 0
=>
+ val newNames = collateToSingleType(namedStruct.nameExprs)
+ val newValues = collateToSingleType(namedStruct.valExprs)
+ val interleaved = newNames.zip(newValues).flatMap(pair => Seq(pair._1,
pair._2))
+ namedStruct.withNewChildren(interleaved)
+
case splitPart: SplitPart =>
val Seq(str, delimiter, partNum) = splitPart.children
val Seq(newStr, newDelimiter) = collateToSingleType(Seq(str, delimiter))
@@ -193,88 +155,221 @@ object CollationTypeCoercion {
case other => other
}
+ /**
+ * If childType is collated and target is UTF8_BINARY, the collation of the
output
+ * should be that of the childType.
+ */
+ private def shouldRemoveCast(cast: Cast): Boolean = {
+ val isUserDefined = cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined
+ val isChildTypeCollatedString = cast.child.dataType match {
+ case st: StringType => !st.isUTF8BinaryCollation
+ case _ => false
+ }
+ val targetType = cast.dataType
+
+ isUserDefined && isChildTypeCollatedString && targetType == StringType
+ }
+
/**
* Extracts StringTypes from filtered hasStringType
*/
@tailrec
- private def extractStringType(dt: DataType): StringType = dt match {
- case st: StringType => st
+ private def extractStringType(dt: DataType): Option[StringType] = dt match {
+ case st: StringType => Some(st)
case ArrayType(et, _) => extractStringType(et)
+ case _ => None
}
/**
* Casts given expression to collated StringType with id equal to
collationId only
* if expression has StringType in the first place.
- * @param expr
- * @param collationId
- * @return
*/
- def castStringType(expr: Expression, st: StringType): Option[Expression] =
- castStringType(expr.dataType, st).map { dt => Cast(expr, dt)}
+ def castStringType(expr: Expression, st: StringType): Expression = {
+ castStringType(expr.dataType, st)
+ .map(dt => Cast(expr, dt))
+ .getOrElse(expr)
+ }
private def castStringType(inType: DataType, castType: StringType):
Option[DataType] = {
- @Nullable val ret: DataType = inType match {
- case st: StringType if st.collationId != castType.collationId => castType
+ inType match {
+ case st: StringType if st.collationId != castType.collationId =>
+ Some(castType)
case ArrayType(arrType, nullable) =>
- castStringType(arrType, castType).map(ArrayType(_, nullable)).orNull
- case _ => null
+ castStringType(arrType, castType).map(ArrayType(_, nullable))
+ case _ => None
}
- Option(ret)
}
/**
* Collates input expressions to a single collation.
*/
- def collateToSingleType(exprs: Seq[Expression]): Seq[Expression] = {
- val st = getOutputCollation(exprs)
+ def collateToSingleType(expressions: Seq[Expression]): Seq[Expression] = {
+ val lctOpt = findLeastCommonStringType(expressions)
- exprs.map(e => castStringType(e, st).getOrElse(e))
+ lctOpt match {
+ case Some(lct) =>
+ expressions.map(e => castStringType(e, lct))
+ case _ =>
+ expressions
+ }
}
/**
- * Based on the data types of the input expressions this method determines
- * a collation type which the output will have. This function accepts Seq of
- * any expressions, but will only be affected by collated StringTypes or
- * complex DataTypes with collated StringTypes (e.g. ArrayType)
+ * Tries to find the least common StringType among the given expressions.
*/
- def getOutputCollation(expr: Seq[Expression]): StringType = {
- val explicitTypes = expr.filter {
- case _: Collate => true
- case _ => false
- }
- .map(_.dataType.asInstanceOf[StringType].collationId)
- .distinct
-
- explicitTypes.size match {
- // We have 1 explicit collation
- case 1 => StringType(explicitTypes.head)
- // Multiple explicit collations occurred
- case size if size > 1 =>
- throw QueryCompilationErrors
- .explicitCollationMismatchError(
- explicitTypes.map(t => StringType(t))
- )
- // Only implicit or default collations present
- case 0 =>
- val implicitTypes = expr.filter {
- case Literal(_, _: StringType) => false
- case cast: Cast if
cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
- cast.child.dataType.isInstanceOf[StringType]
- case _ => true
- }
- .map(_.dataType)
- .filter(hasStringType)
- .map(extractStringType(_).collationId)
- .distinct
-
- if (implicitTypes.length > 1) {
- throw QueryCompilationErrors.implicitCollationMismatchError(
- implicitTypes.map(t => StringType(t))
- )
+ private def findLeastCommonStringType(expressions: Seq[Expression]):
Option[StringType] = {
+ if (!expressions.exists(e =>
SchemaUtils.hasNonUTF8BinaryCollation(e.dataType))) {
+ return None
+ }
+
+ val collationContextWinner =
expressions.foldLeft(findCollationContext(expressions.head)) {
+ case (Some(left), right) =>
+ findCollationContext(right).flatMap { ctx =>
+ collationPrecedenceWinner(left, ctx)
}
- else {
-
implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
+ case (None, _) => return None
+ }
+
+ collationContextWinner.flatMap { cc =>
+ extractStringType(cc.dataType)
+ }
+ }
+
+ /**
+ * Tries to find the collation context for the given expression.
+ * If found, it will also set the [[COLLATION_CONTEXT_TAG]] on the
expression,
+ * so that the context can be reused later.
+ */
+ private def findCollationContext(expr: Expression): Option[CollationContext]
= {
+ val contextOpt = expr match {
+ case _ if hasCollationContextTag(expr) =>
+ Some(expr.getTagValue(COLLATION_CONTEXT_TAG).get)
+
+ // if `expr` doesn't have a string in its dataType then it doesn't
+ // have the collation context either
+ case _ if !expr.dataType.existsRecursively(_.isInstanceOf[StringType]) =>
+ None
+
+ case collate: Collate =>
+ Some(CollationContext(collate.dataType, Explicit))
+
+ case _: Alias | _: SubqueryExpression | _: AttributeReference | _:
VariableReference =>
+ Some(CollationContext(expr.dataType, Implicit))
+
+ case _: Literal =>
+ Some(CollationContext(expr.dataType, Default))
+
+ // if it does have a string type but none of its children do
+ // then the collation context strength is default
+ case _ if
!expr.children.exists(_.dataType.existsRecursively(_.isInstanceOf[StringType]))
=>
+ Some(CollationContext(expr.dataType, Default))
+
+ case _ =>
+ val contextWinnerOpt = getContextRelevantChildren(expr)
+ .flatMap(findCollationContext)
+ .foldLeft(Option.empty[CollationContext]) {
+ case (Some(left), right) =>
+ collationPrecedenceWinner(left, right)
+ case (None, right) =>
+ Some(right)
+ }
+
+ contextWinnerOpt.map { context =>
+ if (hasStringType(expr.dataType)) {
+ CollationContext(expr.dataType, context.strength)
+ } else {
+ context
+ }
}
}
+
+ contextOpt.foreach(expr.setTagValue(COLLATION_CONTEXT_TAG, _))
+ contextOpt
+ }
+
+ /**
+ * Returns the children of the given expression that should be used for
calculating the
+ * winning collation context.
+ */
+ private def getContextRelevantChildren(expression: Expression):
Seq[Expression] = {
+ expression match {
+ // collation context for named struct should be calculated based on its
values only
+ case createStruct: CreateNamedStruct =>
+ createStruct.valExprs
+
+ // collation context does not depend on the key for extracting the value
+ case extract: ExtractValue =>
+ Seq(extract.child)
+
+ // we currently don't support collation precedence for maps,
+ // as this would involve calculating them for keys and values separately
+ case _: CreateMap =>
+ Seq.empty
+
+ case _ =>
+ expression.children
+ }
+ }
+
+ /**
+ * Returns the collation context that wins in precedence between left and
right.
+ */
+ private def collationPrecedenceWinner(
+ left: CollationContext,
+ right: CollationContext): Option[CollationContext] = {
+
+ val (leftStringType, rightStringType) =
+ (extractStringType(left.dataType), extractStringType(right.dataType))
match {
+ case (Some(l), Some(r)) =>
+ (l, r)
+ case (None, None) =>
+ return None
+ case (Some(_), None) =>
+ return Some(left)
+ case (None, Some(_)) =>
+ return Some(right)
+ }
+
+ (left.strength, right.strength) match {
+ case (Explicit, Explicit) if leftStringType != rightStringType =>
+ throw QueryCompilationErrors.explicitCollationMismatchError(
+ Seq(leftStringType, rightStringType))
+
+ case (Explicit, _) => Some(left)
+ case (_, Explicit) => Some(right)
+
+ case (Implicit, Implicit) if leftStringType != rightStringType =>
+ throw QueryCompilationErrors.implicitCollationMismatchError(
+ Seq(leftStringType, rightStringType))
+
+ case (Implicit, _) => Some(left)
+ case (_, Implicit) => Some(right)
+
+ case (Default, Default) if leftStringType != rightStringType =>
+ throw QueryCompilationErrors.implicitCollationMismatchError(
+ Seq(leftStringType, rightStringType))
+
+ case _ =>
+ Some(left)
+ }
}
}
+
+/**
+ * Represents the strength of collation used for determining precedence in
collation resolution.
+ */
+private sealed trait CollationStrength {}
+
+ private object CollationStrength {
+ case object Explicit extends CollationStrength {}
+ case object Implicit extends CollationStrength {}
+ case object Default extends CollationStrength {}
+}
+
+/**
+ * Encapsulates the context for collation, including data type and strength.
+ *
+ * @param dataType The data type associated with this collation context.
+ * @param strength The strength level of the collation, which determines its
precedence.
+ */
+private case class CollationContext(dataType: DataType, strength:
CollationStrength) {}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 2098ee274dfe..e7cc174f7cf3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -565,11 +565,14 @@ case class StringToMap(text: Expression, pairDelim:
Expression, keyValueDelim: E
extends TernaryExpression with ExpectsInputTypes {
override def nullIntolerant: Boolean = true
def this(child: Expression, pairDelim: Expression) = {
- this(child, pairDelim, Literal(":"))
+ this(child, pairDelim, Literal.create(":", SQLConf.get.defaultStringType))
}
def this(child: Expression) = {
- this(child, Literal(","), Literal(":"))
+ this(
+ child,
+ Literal.create(",", SQLConf.get.defaultStringType),
+ Literal.create(":", SQLConf.get.defaultStringType))
}
override def stateful: Boolean = true
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 3b8d4e09905e..2013cd8d6e63 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -93,6 +93,7 @@ object ExtractValue {
trait ExtractValue extends Expression {
override def nullIntolerant: Boolean = true
final override val nodePatterns: Seq[TreePattern] = Seq(EXTRACT_VALUE)
+ val child: Expression
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index d92f45b1968a..c97920619ba4 100755
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -1862,7 +1862,8 @@ trait PadExpressionBuilderBase extends ExpressionBuilder {
if (expressions(0).dataType == BinaryType && behaviorChangeEnabled) {
BinaryPad(funcName, expressions(0), expressions(1),
Literal(Array[Byte](0)))
} else {
- createStringPad(expressions(0), expressions(1), Literal(" "))
+ createStringPad(expressions(0),
+ expressions(1), Literal.create(" ", SQLConf.get.defaultStringType))
}
} else if (numArgs == 3) {
if (expressions(0).dataType == BinaryType && expressions(2).dataType ==
BinaryType
@@ -1992,7 +1993,10 @@ object RPadExpressionBuilder extends
PadExpressionBuilderBase {
}
}
-case class StringRPad(str: Expression, len: Expression, pad: Expression =
Literal(" "))
+case class StringRPad(
+ str: Expression,
+ len: Expression,
+ pad: Expression = Literal.create(" ", SQLConf.get.defaultStringType))
extends TernaryExpression with ImplicitCastInputTypes {
override def nullIntolerant: Boolean = true
override def first: Expression = str
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
index 2b49b76ff8c7..bc62fa5fdd33 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
@@ -742,7 +742,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite
with SharedSparkSessi
}
} catch {
case e: SparkRuntimeException => assert(e.getCondition ==
"USER_RAISED_EXCEPTION")
- case other: Throwable => throw other
+ case other: Throwable => throw new Exception(s"Query $query failed",
other)
}
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 3563e04dced1..6feb4587b816 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -2416,8 +2416,15 @@ class CollationSQLExpressionsSuite
|collate('${testCase.left}', '${testCase.leftCollation}'))=
|collate('${testCase.right}', '${testCase.rightCollation}');
|""".stripMargin
- val testQuery = sql(query)
- checkAnswer(testQuery, Row(testCase.result))
+
+ if (testCase.leftCollation == testCase.rightCollation) {
+ checkAnswer(sql(query), Row(testCase.result))
+ } else {
+ val exception = intercept[AnalysisException] {
+ sql(query)
+ }
+ assert(exception.getCondition === "COLLATION_MISMATCH.EXPLICIT")
+ }
})
val queryPass =
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 9ee2cfb964fe..2a0b84c07507 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -180,11 +180,12 @@ class CollationStringExpressionsSuite
// Because `StringSplitSQL` is an internal expression,
// E2E SQL test cannot be performed in `collations.sql`.
+
checkError(
exception = intercept[AnalysisException] {
val expr = StringSplitSQL(
- Cast(Literal.create("1a2"), StringType("UTF8_BINARY")),
- Cast(Literal.create("a"), StringType("UTF8_LCASE")))
+ Literal.create("1a2", StringType("UTF8_BINARY")),
+ Literal.create("a", StringType("UTF8_LCASE")))
CollationTypeCasts.transform(expr)
},
condition = "COLLATION_MISMATCH.IMPLICIT",
@@ -193,6 +194,7 @@ class CollationStringExpressionsSuite
"implicitTypes" -> """"STRING", "STRING COLLATE UTF8_LCASE""""
)
)
+
checkError(
exception = intercept[AnalysisException] {
val expr = StringSplitSQL(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index f5cb30809ae5..170782005383 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -624,17 +624,11 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 = COLLATE('a',
'UTF8_BINARY')"),
Seq(Row("a")))
- // fail with implicit mismatch, as function return should be considered
implicit
- checkError(
- exception = intercept[AnalysisException] {
- sql(s"SELECT c1 FROM $tableName " +
- s"WHERE c1 = SUBSTR(COLLATE('a', 'UNICODE'), 0)")
- },
- condition = "COLLATION_MISMATCH.IMPLICIT",
- parameters = Map(
- "implicitTypes" -> """"STRING COLLATE UTF8_LCASE", "STRING COLLATE
UNICODE""""
- )
- )
+ // explicit collation propagates up
+ checkAnswer(
+ sql(s"SELECT c1 FROM $tableName " +
+ s"WHERE c1 = SUBSTR(COLLATE('a', 'UNICODE'), 0)"),
+ Row("a"))
// in operator
checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 IN ('a')"),
@@ -742,9 +736,16 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
)
// concat + in
- checkAnswer(sql(s"SELECT c1 FROM $tableName WHERE c1 || COLLATE('a',
'UTF8_BINARY') IN " +
- s"(COLLATE('aa', 'UNICODE'))"),
- Seq(Row("a")))
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(s"SELECT c1 FROM $tableName where c1 || COLLATE('a',
'UTF8_BINARY') IN " +
+ s"(COLLATE('aa', 'UNICODE'))")
+ },
+ condition = "COLLATION_MISMATCH.EXPLICIT",
+ parameters = Map(
+ "explicitTypes" -> """"STRING", "STRING COLLATE UNICODE""""
+ )
+ )
// array creation supports implicit casting
checkAnswer(sql(s"SELECT typeof(array('a' COLLATE UNICODE, 'b')[1])"),
@@ -765,14 +766,21 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
exception = intercept[AnalysisException] {
sql(s"SELECT array('A', 'a' COLLATE UNICODE) == array('b' COLLATE
UNICODE_CI)")
},
- condition = "COLLATION_MISMATCH.IMPLICIT",
+ condition = "COLLATION_MISMATCH.EXPLICIT",
parameters = Map(
- "implicitTypes" -> """"STRING COLLATE UNICODE", "STRING COLLATE
UNICODE_CI""""
+ "explicitTypes" -> """"STRING COLLATE UNICODE", "STRING COLLATE
UNICODE_CI""""
)
)
- checkAnswer(sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c'
collate UNICODE_CI)"),
- Seq(Row("acb")))
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("SELECT array_join(array('a', 'b' collate UNICODE), 'c' collate
UNICODE_CI)")
+ },
+ condition = "COLLATION_MISMATCH.EXPLICIT",
+ parameters = Map(
+ "explicitTypes" -> """"STRING COLLATE UNICODE", "STRING COLLATE
UNICODE_CI""""
+ )
+ )
}
}
@@ -851,26 +859,6 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
}
}
- test("SPARK-47692: Parameter markers with variable mapping") {
- checkAnswer(
- spark.sql(
- "SELECT collation(:var1 || :var2)",
- Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")),
- "var2" -> Literal.create('b', StringType("UNICODE")))),
- Seq(Row("UTF8_BINARY"))
- )
-
- withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
- checkAnswer(
- spark.sql(
- "SELECT collation(:var1 || :var2)",
- Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")),
- "var2" -> Literal.create('b', StringType("UNICODE")))),
- Seq(Row("UNICODE"))
- )
- }
- }
-
test("SPARK-47210: Cast of default collated strings in IN expression") {
val tableName = "t1"
withTable(tableName) {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationTypePrecedenceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationTypePrecedenceSuite.scala
index 6f10acf264b0..4a904a85e0a7 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationTypePrecedenceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationTypePrecedenceSuite.scala
@@ -18,11 +18,10 @@
package org.apache.spark.sql.collation
import org.apache.spark.SparkThrowable
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.connector.DatasourceV2SQLBase
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.test.SharedSparkSession
-class CollationTypePrecedenceSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
+class CollationTypePrecedenceSuite extends QueryTest with SharedSparkSession {
val dataSource: String = "parquet"
@@ -33,6 +32,360 @@ class CollationTypePrecedenceSuite extends
DatasourceV2SQLBase with AdaptiveSpar
assert(exception.getCondition === errorClass)
}
+ private def assertExplicitMismatch(df: => DataFrame): Unit =
+ assertThrowsError(df, "COLLATION_MISMATCH.EXPLICIT")
+
+ private def assertImplicitMismatch(df: => DataFrame): Unit =
+ assertThrowsError(df, "COLLATION_MISMATCH.IMPLICIT")
+
+ test("explicit collation propagates up") {
+ checkAnswer(
+ sql(s"SELECT COLLATION('a' collate unicode)"),
+ Row("UNICODE"))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION('a' collate unicode || 'b')"),
+ Row("UNICODE"))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(SUBSTRING('a' collate unicode, 0, 1))"),
+ Row("UNICODE"))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(SUBSTRING('a' collate unicode, 0, 1) || 'b')"),
+ Row("UNICODE"))
+
+ assertExplicitMismatch(
+ sql(s"SELECT COLLATION('a' collate unicode || 'b' collate utf8_lcase)"))
+
+ assertExplicitMismatch(
+ sql(s"""
+ |SELECT COLLATION(
+ | SUBSTRING('a' collate unicode, 0, 1) ||
+ | SUBSTRING('b' collate utf8_lcase, 0, 1))
+ |""".stripMargin))
+ }
+
+ test("implicit collation in columns") {
+ val tableName = "implicit_coll_tbl"
+ val c1Collation = "UNICODE"
+ val c2Collation = "UNICODE_CI"
+ val structCollation = "UTF8_LCASE"
+ withTable(tableName) {
+ sql(s"""
+ |CREATE TABLE $tableName (
+ | c1 STRING COLLATE $c1Collation,
+ | c2 STRING COLLATE $c2Collation,
+ | c3 STRUCT<col1: STRING COLLATE $structCollation>)
+ |""".stripMargin)
+ sql(s"INSERT INTO $tableName VALUES ('a', 'b', named_struct('col1',
'c'))")
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1 || 'a') FROM $tableName"),
+ Seq(Row(c1Collation)))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(c3.col1 || 'a') FROM $tableName"),
+ Seq(Row(structCollation)))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(SUBSTRING(c1, 0, 1) || 'a') FROM $tableName"),
+ Seq(Row(c1Collation)))
+
+ assertImplicitMismatch(sql(s"SELECT COLLATION(c1 || c2) FROM
$tableName"))
+ assertImplicitMismatch(sql(s"SELECT COLLATION(c1 || c3.col1) FROM
$tableName"))
+ assertImplicitMismatch(
+ sql(s"SELECT COLLATION(SUBSTRING(c1, 0, 1) || c2) FROM $tableName"))
+ }
+ }
+
+ test("variables have implicit collation") {
+ val v1Collation = "UTF8_BINARY"
+ val v2Collation = "UTF8_LCASE"
+ sql(s"DECLARE v1 = 'a'")
+ sql(s"DECLARE v2 = 'b' collate $v2Collation")
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(v1 || 'a')"),
+ Row(v1Collation))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(v2 || 'a')"),
+ Row(v2Collation))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(v2 || 'a' COLLATE UTF8_BINARY)"),
+ Row("UTF8_BINARY"))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(SUBSTRING(v2, 0, 1) || 'a')"),
+ Row(v2Collation))
+
+ assertImplicitMismatch(sql(s"SELECT COLLATION(v1 || v2)"))
+ assertImplicitMismatch(sql(s"SELECT COLLATION(SUBSTRING(v1, 0, 1) || v2)"))
+ }
+
+ test("subqueries have implicit collation strength") {
+ withTable("t") {
+ sql(s"CREATE TABLE t (c STRING COLLATE UTF8_LCASE) USING $dataSource")
+
+ sql(s"SELECT (SELECT 'text' COLLATE UTF8_BINARY) || c collate
UTF8_BINARY from t")
+ assertImplicitMismatch(
+ sql(s"SELECT (SELECT 'text' COLLATE UTF8_BINARY) || c from t"))
+ }
+
+ // Simple subquery with explicit collation
+ checkAnswer(
+ sql(s"SELECT COLLATION((SELECT 'text' COLLATE UTF8_BINARY) ||
'suffix')"),
+ Row("UTF8_BINARY")
+ )
+
+ checkAnswer(
+ sql(s"SELECT COLLATION((SELECT 'text' COLLATE UTF8_LCASE) || 'suffix')"),
+ Row("UTF8_LCASE")
+ )
+
+ // Nested subquery should retain the collation of the deepest expression
+ checkAnswer(
+ sql(s"SELECT COLLATION((SELECT (SELECT 'inner' COLLATE UTF8_LCASE) ||
'outer'))"),
+ Row("UTF8_LCASE")
+ )
+
+ checkAnswer(
+ sql(s"SELECT COLLATION((SELECT (SELECT 'inner' COLLATE UTF8_BINARY) ||
'outer'))"),
+ Row("UTF8_BINARY")
+ )
+
+ // Subqueries with mixed collations should follow collation precedence
rules
+ checkAnswer(
+ sql(s"SELECT COLLATION((SELECT 'string1' COLLATE UTF8_LCASE || " +
+ s"(SELECT 'string2' COLLATE UTF8_BINARY)))"),
+ Row("UTF8_LCASE")
+ )
+ }
+
+ test("struct test") {
+ val tableName = "struct_tbl"
+ val c1Collation = "UNICODE_CI"
+ val c2Collation = "UNICODE"
+ withTable(tableName) {
+ sql(s"""
+ |CREATE TABLE $tableName (
+ | c1 STRUCT<col1: STRING COLLATE $c1Collation>,
+ | c2 STRUCT<col1: STRUCT<col1: STRING COLLATE $c2Collation>>)
+ |USING $dataSource
+ |""".stripMargin)
+ sql(s"INSERT INTO $tableName VALUES (named_struct('col1', 'a')," +
+ s"named_struct('col1', named_struct('col1', 'c')))")
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(c2.col1.col1 || 'a') FROM $tableName"),
+ Seq(Row(c2Collation)))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1.col1 || 'a') FROM $tableName"),
+ Seq(Row(c1Collation)))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1.col1 || 'a' collate UNICODE) FROM
$tableName"),
+ Seq(Row("UNICODE")))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(struct('a').col1 || 'a' collate UNICODE) FROM
$tableName"),
+ Seq(Row("UNICODE")))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(struct('a' collate UNICODE).col1 || 'a') FROM
$tableName"),
+ Seq(Row("UNICODE")))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(struct('a').col1 collate UNICODE || 'a' collate
UNICODE) " +
+ s"FROM $tableName"),
+ Seq(Row("UNICODE")))
+
+ assertExplicitMismatch(
+ sql(s"SELECT COLLATION(struct('a').col1 collate UNICODE || 'a' collate
UTF8_LCASE) " +
+ s"FROM $tableName"))
+
+ assertExplicitMismatch(
+ sql(s"SELECT COLLATION(struct('a' collate UNICODE).col1 || 'a' collate
UTF8_LCASE) " +
+ s"FROM $tableName"))
+ }
+ }
+
+ // SPARK-50175: collation precedence for array elements.
+ // Array elements carry the array's declared collation (implicit strength);
+ // an explicit COLLATE on any element wins and propagates through
+ // element_at/array_append, while conflicting explicit collations error out.
+ test("array test") {
+ val tableName = "array_tbl"
+ val columnCollation = "UNICODE"
+ val arrayCollation = "UNICODE_CI"
+ withTable(tableName) {
+ sql(s"""
+ |CREATE TABLE $tableName (
+ | c1 STRING COLLATE $columnCollation,
+ | c2 ARRAY<STRING COLLATE $arrayCollation>)
+ |USING $dataSource
+ |""".stripMargin)
+
+ sql(s"INSERT INTO $tableName VALUES ('a', array('b', 'c'))")
+
+ // Explicit collation on one element propagates to the whole array.
+ checkAnswer(
+ sql(s"SELECT collation(element_at(array('a', 'b' collate utf8_lcase),
1))"),
+ Seq(Row("UTF8_LCASE")))
+
+ // Two different explicit collations inside one array constructor fail.
+ assertExplicitMismatch(
+ sql(s"SELECT collation(element_at(array('a' collate unicode, 'b'
collate utf8_lcase), 1))")
+ )
+
+ // Explicit strength from inside the array beats the implicit column c1.
+ checkAnswer(
+ sql(s"SELECT collation(element_at(array('a', 'b' collate utf8_lcase),
1) || c1)" +
+ s"from $tableName"),
+ Seq(Row("UTF8_LCASE")))
+
+ // Implicit array collation beats the default-strength appended literal.
+ checkAnswer(
+ sql(s"SELECT collation(element_at(array_append(c2, 'd'), 1)) FROM
$tableName"),
+ Seq(Row(arrayCollation))
+ )
+
+ // Explicit collation on the appended element beats the array's implicit one.
+ checkAnswer(
+ sql(s"SELECT collation(element_at(array_append(c2, 'd' collate
utf8_lcase), 1))" +
+ s"FROM $tableName"),
+ Seq(Row("UTF8_LCASE"))
+ )
+ }
+ }
+
+ // Casting an array to a different element collation (or to plain STRING)
+ // resets the element collation to the cast's target type.
+ test("array cast") {
+ val tableName = "array_cast_tbl"
+ val columnCollation = "UNICODE"
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (c1 ARRAY<STRING COLLATE
$columnCollation>) USING $dataSource")
+ sql(s"INSERT INTO $tableName VALUES (array('a'))")
+
+ // Without a cast the element keeps the column's declared collation.
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1[0]) FROM $tableName"),
+ Seq(Row(columnCollation)))
+
+ // Casting to plain STRING resets elements to the default UTF8_BINARY.
+ checkAnswer(
+ sql(s"SELECT COLLATION(cast(c1 AS ARRAY<STRING>)[0]) FROM $tableName"),
+ Seq(Row("UTF8_BINARY")))
+
+ // Casting to a collated string type applies that collation to elements.
+ checkAnswer(
+ sql(s"SELECT COLLATION(cast(c1 AS ARRAY<STRING COLLATE
UTF8_LCASE>)[0]) FROM $tableName"),
+ Seq(Row("UTF8_LCASE")))
+ }
+ }
+
+ // Collation strength of CAST: casting a non-string input to STRING yields a
+ // default-strength result, while casting a string child keeps the child's
+ // collation and strength (default, implicit, or explicit).
+ test("user defined cast") {
+ val tableName = "dflt_coll_tbl"
+ val columnCollation = "UNICODE"
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (c1 STRING COLLATE $columnCollation) USING
$dataSource")
+ sql(s"INSERT INTO $tableName VALUES ('a')")
+
+ // only for non string inputs cast results in default collation
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1 || CAST(to_char(DATE'2016-04-08', 'y') AS
STRING)) " +
+ s"FROM $tableName"),
+ Seq(Row(columnCollation)))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(CAST(to_char(DATE'2016-04-08', 'y') AS STRING))
" +
+ s"FROM $tableName"),
+ Seq(Row("UTF8_BINARY")))
+
+ // for string inputs collation is of the child expression
+ checkAnswer(
+ sql(s"SELECT COLLATION(CAST('a' AS STRING)) FROM $tableName"),
+ Seq(Row("UTF8_BINARY")))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(CAST(c1 AS STRING)) FROM $tableName"),
+ Seq(Row(columnCollation)))
+
+ // Explicit strength survives the cast.
+ checkAnswer(
+ sql(s"SELECT COLLATION(CAST(c1 collate UTF8_LCASE AS STRING)) FROM
$tableName"),
+ Seq(Row("UTF8_LCASE")))
+
+ // Implicit column collation beats the default-strength cast of a literal.
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1 || CAST('a' AS STRING)) FROM $tableName"),
+ Seq(Row(columnCollation)))
+
+ // Explicit collation inside the cast beats the implicit column.
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1 || CAST('a' collate UTF8_LCASE AS STRING))
FROM $tableName"),
+ Seq(Row("UTF8_LCASE")))
+
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1 || CAST(c1 AS STRING)) FROM $tableName"),
+ Seq(Row(columnCollation)))
+
+ // Implicit strength also propagates through nested string functions.
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1 || SUBSTRING(CAST(c1 AS STRING), 0, 1)) FROM
$tableName"),
+ Seq(Row(columnCollation)))
+ }
+ }
+
+ // String functions with no string inputs (e.g. current_database()) produce
+ // default-strength output, so they never override an explicit or implicit
+ // collation on the other concatenation operand.
+ test("str fns without params have default strength") {
+ val tableName = "str_fns_tbl"
+ val columnCollation = "UNICODE"
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (c1 STRING COLLATE $columnCollation) USING
$dataSource")
+ sql(s"INSERT INTO $tableName VALUES ('a')")
+
+ // Explicit collation on the literal wins over the default-strength fn.
+ checkAnswer(
+ sql(s"SELECT COLLATION('a' collate utf8_lcase || current_database())
FROM $tableName"),
+ Seq(Row("UTF8_LCASE")))
+
+ // Implicit column collation wins over the default-strength fn.
+ checkAnswer(
+ sql(s"SELECT COLLATION(c1 || current_database()) FROM $tableName"),
+ Seq(Row(columnCollation)))
+
+ // default || default stays at the session default UTF8_BINARY.
+ checkAnswer(
+ sql(s"SELECT COLLATION('a' || current_database()) FROM $tableName"),
+ Seq(Row("UTF8_BINARY")))
+ }
+ }
+
+ // Functions mixing string and non-string arguments (elt's first argument is
+ // an index) derive the result collation only from the string arguments.
+ test("functions that contain both string and non string params") {
+ checkAnswer(
+ sql(s"SELECT COLLATION(elt(2, 'a', 'b'))"),
+ Row("UTF8_BINARY"))
+
+ // Explicit collation on one string argument wins.
+ checkAnswer(
+ sql(s"SELECT COLLATION(elt(2, 'a' collate UTF8_LCASE, 'b'))"),
+ Row("UTF8_LCASE"))
+
+ // Conflicting explicit collations across string arguments must fail.
+ assertExplicitMismatch(
+ sql(s"SELECT COLLATION(elt(2, 'a' collate UTF8_LCASE, 'b' collate
UNICODE))"))
+ }
+
+ // named_struct: field names and field values are independent collation
+ // domains — a name and a value may carry different explicit collations, but
+ // conflicting explicit collations among names (or among values) must fail.
+ test("named_struct names and values") {
+ checkAnswer(
+ sql(s"SELECT named_struct('name1', 'value1', 'name2', 'value2')"),
+ Row(Row("value1", "value2")))
+
+ // Matching explicit collations on the names are allowed.
+ checkAnswer(
+ sql(s"SELECT named_struct" +
+ s"('name1' collate unicode, 'value1', 'name2' collate unicode,
'value2')"),
+ Row(Row("value1", "value2")))
+
+ // Matching explicit collations on the values are allowed.
+ checkAnswer(
+ sql(s"SELECT named_struct" +
+ s"('name1', 'value1' collate unicode, 'name2', 'value2' collate
unicode)"),
+ Row(Row("value1", "value2")))
+
+ // Names and values may use different explicit collations.
+ checkAnswer(
+ sql(s"SELECT named_struct('name1' collate utf8_lcase, 'value1' collate
unicode," +
+ s"'name2' collate utf8_lcase, 'value2' collate unicode)"),
+ Row(Row("value1", "value2")))
+
+ // Conflicting explicit collations among the names fail...
+ assertExplicitMismatch(
+ sql(s"SELECT named_struct" +
+ s"('name1' collate unicode, 'value1', 'name2' collate utf8_lcase,
'value2')"))
+
+ // ...as do conflicting explicit collations among the values.
+ assertExplicitMismatch(
+ sql(s"SELECT named_struct" +
+ s"('name1', 'value1' collate unicode, 'name2', 'value2' collate
utf8_lcase)"))
+ }
+
test("access collated map via literal") {
val tableName = "map_with_lit"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]