This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 31036274fc1c [SPARK-47410][SQL] Refactor UTF8String and
CollationFactory
31036274fc1c is described below
commit 31036274fc1c8013c6428735659959f46afea5d8
Author: Uros Bojanic <[email protected]>
AuthorDate: Thu Apr 11 22:21:21 2024 +0800
[SPARK-47410][SQL] Refactor UTF8String and CollationFactory
### What changes were proposed in this pull request?
This PR introduces comprehensive support for collation-aware expressions in
Spark, focusing on improving code structure, clarity, and testing coverage for
various expressions (including: Contains, StartsWith, EndsWith).
### Why are the changes needed?
The changes are essential to improve the maintainability and readability of
collation-related code in Spark expressions. By restructuring and centralizing
collation support through the new CollationSupport class, we simplify the
addition of new collation-aware operations and ensure consistent testing across
different collation types.
### Does this PR introduce _any_ user-facing change?
No, this PR is focused on internal refactoring and testing enhancements for
collation-aware expression support.
### How was this patch tested?
Unit tests in CollationSupportSuite.java
E2E tests in CollationStringExpressionsSuite.scala
### Was this patch authored or co-authored using generative AI tooling?
Yes.
Closes #45978 from uros-db/SPARK-47410.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/util/CollationFactory.java | 54 +-
.../spark/sql/catalyst/util/CollationSupport.java | 174 ++++++
.../org/apache/spark/unsafe/types/UTF8String.java | 54 --
.../spark/unsafe/types/CollationSupportSuite.java | 266 +++++++++
.../unsafe/types/UTF8StringWithCollationSuite.java | 103 ----
.../expressions/codegen/CodeGenerator.scala | 3 +-
.../catalyst/expressions/stringExpressions.scala | 41 +-
.../sql/CollationRegexpExpressionsSuite.scala | 616 +++++++++------------
.../sql/CollationStringExpressionsSuite.scala | 179 ++++--
.../org/apache/spark/sql/CollationSuite.scala | 84 ---
10 files changed, 874 insertions(+), 700 deletions(-)
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 72a6e574707f..ff7bc450f851 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -78,6 +78,14 @@ public final class CollationFactory {
*/
public final boolean supportsBinaryOrdering;
+ /**
+ * Support for Lowercase Equality implies that it is possible to check
equality on
+ * byte by byte level, but only after calling "UTF8String.toLowerCase" on
both arguments.
+ * This allows custom collation support for UTF8_BINARY_LCASE collation in
various Spark
+ * expressions, as this particular collation is not supported by the
external ICU library.
+ */
+ public final boolean supportsLowercaseEquality;
+
public Collation(
String collationName,
Collator collator,
@@ -85,7 +93,8 @@ public final class CollationFactory {
String version,
ToLongFunction<UTF8String> hashFunction,
boolean supportsBinaryEquality,
- boolean supportsBinaryOrdering) {
+ boolean supportsBinaryOrdering,
+ boolean supportsLowercaseEquality) {
this.collationName = collationName;
this.collator = collator;
this.comparator = comparator;
@@ -93,9 +102,12 @@ public final class CollationFactory {
this.hashFunction = hashFunction;
this.supportsBinaryEquality = supportsBinaryEquality;
this.supportsBinaryOrdering = supportsBinaryOrdering;
+ this.supportsLowercaseEquality = supportsLowercaseEquality;
// De Morgan's Law to check supportsBinaryOrdering =>
supportsBinaryEquality
assert(!supportsBinaryOrdering || supportsBinaryEquality);
+ // No Collation can simultaneously support binary equality and lowercase
equality
+ assert(!supportsBinaryEquality || !supportsLowercaseEquality);
if (supportsBinaryEquality) {
this.equalsFunction = UTF8String::equals;
@@ -112,7 +124,8 @@ public final class CollationFactory {
Collator collator,
String version,
boolean supportsBinaryEquality,
- boolean supportsBinaryOrdering) {
+ boolean supportsBinaryOrdering,
+ boolean supportsLowercaseEquality) {
this(
collationName,
collator,
@@ -120,7 +133,8 @@ public final class CollationFactory {
version,
s -> (long)collator.getCollationKey(s.toString()).hashCode(),
supportsBinaryEquality,
- supportsBinaryOrdering);
+ supportsBinaryOrdering,
+ supportsLowercaseEquality);
}
}
@@ -141,7 +155,8 @@ public final class CollationFactory {
"1.0",
s -> (long)s.hashCode(),
true,
- true);
+ true,
+ false);
// Case-insensitive UTF8 binary collation.
// TODO: Do in place comparisons instead of creating new strings.
@@ -152,17 +167,18 @@ public final class CollationFactory {
"1.0",
(s) -> (long)s.toLowerCase().hashCode(),
false,
- false);
+ false,
+ true);
// UNICODE case sensitive comparison (ROOT locale, in ICU).
collationTable[2] = new Collation(
- "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true,
false);
+ "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true,
false, false);
collationTable[2].collator.setStrength(Collator.TERTIARY);
collationTable[2].collator.freeze();
// UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary
strength).
collationTable[3] = new Collation(
- "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false,
false);
+ "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false,
false, false);
collationTable[3].collator.setStrength(Collator.SECONDARY);
collationTable[3].collator.freeze();
@@ -172,19 +188,31 @@ public final class CollationFactory {
}
/**
- * Auxiliary methods for collation aware string operations.
+ * Returns a StringSearch object for the given pattern and target strings,
under collation
+ * rules corresponding to the given collationId. The external ICU library
StringSearch object can
+ * be used to find occurrences of the pattern in the target string, while
respecting collation.
*/
-
public static StringSearch getStringSearch(
- final UTF8String left,
- final UTF8String right,
+ final UTF8String targetUTF8String,
+ final UTF8String patternUTF8String,
final int collationId) {
- String pattern = right.toString();
- CharacterIterator target = new StringCharacterIterator(left.toString());
+ String pattern = patternUTF8String.toString();
+ CharacterIterator target = new
StringCharacterIterator(targetUTF8String.toString());
Collator collator = CollationFactory.fetchCollation(collationId).collator;
return new StringSearch(pattern, target, (RuleBasedCollator) collator);
}
+ /**
+ * Returns a collation-unaware StringSearch object for the given pattern and
target strings.
+ * While this object does not respect collation, it can be used to find
occurrences of the pattern
+ * in the target string for UTF8_BINARY or UTF8_BINARY_LCASE (if arguments
are lowercased).
+ */
+ public static StringSearch getStringSearch(
+ final UTF8String targetUTF8String,
+ final UTF8String patternUTF8String) {
+ return new StringSearch(patternUTF8String.toString(),
targetUTF8String.toString());
+ }
+
/**
* Returns the collation id for the given collation name.
*/
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
new file mode 100644
index 000000000000..fe1952921b7f
--- /dev/null
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.util;
+
+import com.ibm.icu.text.StringSearch;
+
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * Static entry point for collation-aware expressions (StringExpressions,
RegexpExpressions, and
+ * other expressions that require custom collation support), as well as
private utility methods for
+ * collation-aware UTF8String operations needed to implement them.
+ */
+public final class CollationSupport {
+
+ /**
+ * Collation-aware string expressions.
+ */
+
+ public static class Contains {
+ public static boolean exec(final UTF8String l, final UTF8String r, final
int collationId) {
+ CollationFactory.Collation collation =
CollationFactory.fetchCollation(collationId);
+ if (collation.supportsBinaryEquality) {
+ return execBinary(l, r);
+ } else if (collation.supportsLowercaseEquality) {
+ return execLowercase(l, r);
+ } else {
+ return execICU(l, r, collationId);
+ }
+ }
+ public static String genCode(final String l, final String r, final int
collationId) {
+ CollationFactory.Collation collation =
CollationFactory.fetchCollation(collationId);
+ String expr = "CollationSupport.Contains.exec";
+ if (collation.supportsBinaryEquality) {
+ return String.format(expr + "Binary(%s, %s)", l, r);
+ } else if (collation.supportsLowercaseEquality) {
+ return String.format(expr + "Lowercase(%s, %s)", l, r);
+ } else {
+ return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId);
+ }
+ }
+ public static boolean execBinary(final UTF8String l, final UTF8String r) {
+ return l.contains(r);
+ }
+ public static boolean execLowercase(final UTF8String l, final UTF8String
r) {
+ return l.toLowerCase().contains(r.toLowerCase());
+ }
+ public static boolean execICU(final UTF8String l, final UTF8String r,
+ final int collationId) {
+ if (r.numBytes() == 0) return true;
+ if (l.numBytes() == 0) return false;
+ StringSearch stringSearch = CollationFactory.getStringSearch(l, r,
collationId);
+ return stringSearch.first() != StringSearch.DONE;
+ }
+ }
+
+ public static class StartsWith {
+ public static boolean exec(final UTF8String l, final UTF8String r,
+ final int collationId) {
+ CollationFactory.Collation collation =
CollationFactory.fetchCollation(collationId);
+ if (collation.supportsBinaryEquality) {
+ return execBinary(l, r);
+ } else if (collation.supportsLowercaseEquality) {
+ return execLowercase(l, r);
+ } else {
+ return execICU(l, r, collationId);
+ }
+ }
+ public static String genCode(final String l, final String r, final int
collationId) {
+ CollationFactory.Collation collation =
CollationFactory.fetchCollation(collationId);
+ String expr = "CollationSupport.StartsWith.exec";
+ if (collation.supportsBinaryEquality) {
+ return String.format(expr + "Binary(%s, %s)", l, r);
+ } else if (collation.supportsLowercaseEquality) {
+ return String.format(expr + "Lowercase(%s, %s)", l, r);
+ } else {
+ return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId);
+ }
+ }
+ public static boolean execBinary(final UTF8String l, final UTF8String r) {
+ return l.startsWith(r);
+ }
+ public static boolean execLowercase(final UTF8String l, final UTF8String
r) {
+ return l.toLowerCase().startsWith(r.toLowerCase());
+ }
+ public static boolean execICU(final UTF8String l, final UTF8String r,
+ final int collationId) {
+ return CollationAwareUTF8String.matchAt(l, r, 0, collationId);
+ }
+ }
+
+ public static class EndsWith {
+ public static boolean exec(final UTF8String l, final UTF8String r, final
int collationId) {
+ CollationFactory.Collation collation =
CollationFactory.fetchCollation(collationId);
+ if (collation.supportsBinaryEquality) {
+ return execBinary(l, r);
+ } else if (collation.supportsLowercaseEquality) {
+ return execLowercase(l, r);
+ } else {
+ return execICU(l, r, collationId);
+ }
+ }
+ public static String genCode(final String l, final String r, final int
collationId) {
+ CollationFactory.Collation collation =
CollationFactory.fetchCollation(collationId);
+ String expr = "CollationSupport.EndsWith.exec";
+ if (collation.supportsBinaryEquality) {
+ return String.format(expr + "Binary(%s, %s)", l, r);
+ } else if (collation.supportsLowercaseEquality) {
+ return String.format(expr + "Lowercase(%s, %s)", l, r);
+ } else {
+ return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId);
+ }
+ }
+ public static boolean execBinary(final UTF8String l, final UTF8String r) {
+ return l.endsWith(r);
+ }
+ public static boolean execLowercase(final UTF8String l, final UTF8String
r) {
+ return l.toLowerCase().endsWith(r.toLowerCase());
+ }
+ public static boolean execICU(final UTF8String l, final UTF8String r,
+ final int collationId) {
+ return CollationAwareUTF8String.matchAt(l, r, l.numBytes() -
r.numBytes(), collationId);
+ }
+ }
+
+ // TODO: Add more collation-aware string expressions.
+
+ /**
+ * Collation-aware regexp expressions.
+ */
+
+ // TODO: Add more collation-aware regexp expressions.
+
+ /**
+ * Other collation-aware expressions.
+ */
+
+ // TODO: Add other collation-aware expressions.
+
+ /**
+ * Utility class for collation-aware UTF8String operations.
+ */
+
+ private static class CollationAwareUTF8String {
+
+ private static boolean matchAt(final UTF8String target, final UTF8String
pattern,
+ final int pos, final int collationId) {
+ if (pattern.numChars() + pos > target.numChars() || pos < 0) {
+ return false;
+ }
+ if (pattern.numBytes() == 0 || target.numBytes() == 0) {
+ return pattern.numBytes() == 0;
+ }
+ return CollationFactory.getStringSearch(target.substring(
+ pos, pos + pattern.numChars()), pattern, collationId).last() == 0;
+ }
+
+ }
+
+}
diff --git
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 2006efb07a04..2009f1d20442 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -30,7 +30,6 @@ import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
-import com.ibm.icu.text.StringSearch;
import org.apache.spark.sql.catalyst.util.CollationFactory;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UTF8StringBuilder;
@@ -342,28 +341,6 @@ public final class UTF8String implements
Comparable<UTF8String>, Externalizable,
return false;
}
- public boolean contains(final UTF8String substring, int collationId) {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- return this.contains(substring);
- }
- if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
- return this.toLowerCase().contains(substring.toLowerCase());
- }
- return collatedContains(substring, collationId);
- }
-
- private boolean collatedContains(final UTF8String substring, int
collationId) {
- if (substring.numBytes == 0) return true;
- if (this.numBytes == 0) return false;
- StringSearch stringSearch = CollationFactory.getStringSearch(this,
substring, collationId);
- while (stringSearch.next() != StringSearch.DONE) {
- if (stringSearch.getMatchLength() == stringSearch.getPattern().length())
{
- return true;
- }
- }
- return false;
- }
-
/**
* Returns the byte at position `i`.
*/
@@ -378,45 +355,14 @@ public final class UTF8String implements
Comparable<UTF8String>, Externalizable,
return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset,
s.numBytes);
}
- private boolean matchAt(final UTF8String s, int pos, int collationId) {
- if (s.numChars() + pos > this.numChars() || pos < 0) {
- return false;
- }
- if (s.numBytes == 0 || this.numBytes == 0) {
- return s.numBytes == 0;
- }
- return CollationFactory.getStringSearch(this.substring(pos, pos +
s.numChars()),
- s, collationId).last() == 0;
- }
-
public boolean startsWith(final UTF8String prefix) {
return matchAt(prefix, 0);
}
- public boolean startsWith(final UTF8String prefix, int collationId) {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- return this.startsWith(prefix);
- }
- if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
- return this.toLowerCase().startsWith(prefix.toLowerCase());
- }
- return matchAt(prefix, 0, collationId);
- }
-
public boolean endsWith(final UTF8String suffix) {
return matchAt(suffix, numBytes - suffix.numBytes);
}
- public boolean endsWith(final UTF8String suffix, int collationId) {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- return this.endsWith(suffix);
- }
- if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
- return this.toLowerCase().endsWith(suffix.toLowerCase());
- }
- return matchAt(suffix, numBytes - suffix.numBytes, collationId);
- }
-
/**
* Returns the upper case of this string
*/
diff --git
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
new file mode 100644
index 000000000000..bfb696c35fff
--- /dev/null
+++
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.unsafe.types;
+
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.catalyst.util.CollationFactory;
+import org.apache.spark.sql.catalyst.util.CollationSupport;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+
+public class CollationSupportSuite {
+
+ /**
+ * Collation-aware string expressions.
+ */
+
+ private void assertContains(String pattern, String target, String
collationName, boolean value)
+ throws SparkException {
+ UTF8String l = UTF8String.fromString(pattern);
+ UTF8String r = UTF8String.fromString(target);
+ int collationId = CollationFactory.collationNameToId(collationName);
+ assertEquals(CollationSupport.Contains.exec(l, r, collationId), value);
+ }
+
+ @Test
+ public void testContains() throws SparkException {
+ // Edge cases
+ assertContains("", "", "UTF8_BINARY", true);
+ assertContains("c", "", "UTF8_BINARY", true);
+ assertContains("", "c", "UTF8_BINARY", false);
+ assertContains("", "", "UNICODE", true);
+ assertContains("c", "", "UNICODE", true);
+ assertContains("", "c", "UNICODE", false);
+ assertContains("", "", "UTF8_BINARY_LCASE", true);
+ assertContains("c", "", "UTF8_BINARY_LCASE", true);
+ assertContains("", "c", "UTF8_BINARY_LCASE", false);
+ assertContains("", "", "UNICODE_CI", true);
+ assertContains("c", "", "UNICODE_CI", true);
+ assertContains("", "c", "UNICODE_CI", false);
+ // Basic tests
+ assertContains("abcde", "bcd", "UTF8_BINARY", true);
+ assertContains("abcde", "bde", "UTF8_BINARY", false);
+ assertContains("abcde", "fgh", "UTF8_BINARY", false);
+ assertContains("abcde", "abcde", "UNICODE", true);
+ assertContains("abcde", "aBcDe", "UNICODE", false);
+ assertContains("abcde", "fghij", "UNICODE", false);
+ assertContains("abcde", "C", "UTF8_BINARY_LCASE", true);
+ assertContains("abcde", "AbCdE", "UTF8_BINARY_LCASE", true);
+ assertContains("abcde", "X", "UTF8_BINARY_LCASE", false);
+ assertContains("abcde", "c", "UNICODE_CI", true);
+ assertContains("abcde", "bCD", "UNICODE_CI", true);
+ assertContains("abcde", "123", "UNICODE_CI", false);
+ // Case variation
+ assertContains("aBcDe", "bcd", "UTF8_BINARY", false);
+ assertContains("aBcDe", "BcD", "UTF8_BINARY", true);
+ assertContains("aBcDe", "abcde", "UNICODE", false);
+ assertContains("aBcDe", "aBcDe", "UNICODE", true);
+ assertContains("aBcDe", "bcd", "UTF8_BINARY_LCASE", true);
+ assertContains("aBcDe", "BCD", "UTF8_BINARY_LCASE", true);
+ assertContains("aBcDe", "abcde", "UNICODE_CI", true);
+ assertContains("aBcDe", "AbCdE", "UNICODE_CI", true);
+ // Accent variation
+ assertContains("aBcDe", "bćd", "UTF8_BINARY", false);
+ assertContains("aBcDe", "BćD", "UTF8_BINARY", false);
+ assertContains("aBcDe", "abćde", "UNICODE", false);
+ assertContains("aBcDe", "aBćDe", "UNICODE", false);
+ assertContains("aBcDe", "bćd", "UTF8_BINARY_LCASE", false);
+ assertContains("aBcDe", "BĆD", "UTF8_BINARY_LCASE", false);
+ assertContains("aBcDe", "abćde", "UNICODE_CI", false);
+ assertContains("aBcDe", "AbĆdE", "UNICODE_CI", false);
+ // Variable byte length characters
+ assertContains("ab世De", "b世D", "UTF8_BINARY", true);
+ assertContains("ab世De", "B世d", "UTF8_BINARY", false);
+ assertContains("äbćδe", "bćδ", "UTF8_BINARY", true);
+ assertContains("äbćδe", "BcΔ", "UTF8_BINARY", false);
+ assertContains("ab世De", "ab世De", "UNICODE", true);
+ assertContains("ab世De", "AB世dE", "UNICODE", false);
+ assertContains("äbćδe", "äbćδe", "UNICODE", true);
+ assertContains("äbćδe", "ÄBcΔÉ", "UNICODE", false);
+ assertContains("ab世De", "b世D", "UTF8_BINARY_LCASE", true);
+ assertContains("ab世De", "B世d", "UTF8_BINARY_LCASE", true);
+ assertContains("äbćδe", "bćδ", "UTF8_BINARY_LCASE", true);
+ assertContains("äbćδe", "BcΔ", "UTF8_BINARY_LCASE", false);
+ assertContains("ab世De", "ab世De", "UNICODE_CI", true);
+ assertContains("ab世De", "AB世dE", "UNICODE_CI", true);
+ assertContains("äbćδe", "ÄbćδE", "UNICODE_CI", true);
+ assertContains("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false);
+ }
+
+ private void assertStartsWith(String pattern, String prefix, String
collationName, boolean value)
+ throws SparkException {
+ UTF8String l = UTF8String.fromString(pattern);
+ UTF8String r = UTF8String.fromString(prefix);
+ int collationId = CollationFactory.collationNameToId(collationName);
+ assertEquals(CollationSupport.StartsWith.exec(l, r, collationId), value);
+ }
+
+ @Test
+ public void testStartsWith() throws SparkException {
+ // Edge cases
+ assertStartsWith("", "", "UTF8_BINARY", true);
+ assertStartsWith("c", "", "UTF8_BINARY", true);
+ assertStartsWith("", "c", "UTF8_BINARY", false);
+ assertStartsWith("", "", "UNICODE", true);
+ assertStartsWith("c", "", "UNICODE", true);
+ assertStartsWith("", "c", "UNICODE", false);
+ assertStartsWith("", "", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("c", "", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("", "c", "UTF8_BINARY_LCASE", false);
+ assertStartsWith("", "", "UNICODE_CI", true);
+ assertStartsWith("c", "", "UNICODE_CI", true);
+ assertStartsWith("", "c", "UNICODE_CI", false);
+ // Basic tests
+ assertStartsWith("abcde", "abc", "UTF8_BINARY", true);
+ assertStartsWith("abcde", "abd", "UTF8_BINARY", false);
+ assertStartsWith("abcde", "fgh", "UTF8_BINARY", false);
+ assertStartsWith("abcde", "abcde", "UNICODE", true);
+ assertStartsWith("abcde", "aBcDe", "UNICODE", false);
+ assertStartsWith("abcde", "fghij", "UNICODE", false);
+ assertStartsWith("abcde", "A", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("abcde", "AbCdE", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("abcde", "X", "UTF8_BINARY_LCASE", false);
+ assertStartsWith("abcde", "a", "UNICODE_CI", true);
+ assertStartsWith("abcde", "aBC", "UNICODE_CI", true);
+ assertStartsWith("abcde", "123", "UNICODE_CI", false);
+ // Case variation
+ assertStartsWith("aBcDe", "abc", "UTF8_BINARY", false);
+ assertStartsWith("aBcDe", "aBc", "UTF8_BINARY", true);
+ assertStartsWith("aBcDe", "abcde", "UNICODE", false);
+ assertStartsWith("aBcDe", "aBcDe", "UNICODE", true);
+ assertStartsWith("aBcDe", "abc", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("aBcDe", "ABC", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("aBcDe", "abcde", "UNICODE_CI", true);
+ assertStartsWith("aBcDe", "AbCdE", "UNICODE_CI", true);
+ // Accent variation
+ assertStartsWith("aBcDe", "abć", "UTF8_BINARY", false);
+ assertStartsWith("aBcDe", "aBć", "UTF8_BINARY", false);
+ assertStartsWith("aBcDe", "abćde", "UNICODE", false);
+ assertStartsWith("aBcDe", "aBćDe", "UNICODE", false);
+ assertStartsWith("aBcDe", "abć", "UTF8_BINARY_LCASE", false);
+ assertStartsWith("aBcDe", "ABĆ", "UTF8_BINARY_LCASE", false);
+ assertStartsWith("aBcDe", "abćde", "UNICODE_CI", false);
+ assertStartsWith("aBcDe", "AbĆdE", "UNICODE_CI", false);
+ // Variable byte length characters
+ assertStartsWith("ab世De", "ab世", "UTF8_BINARY", true);
+ assertStartsWith("ab世De", "aB世", "UTF8_BINARY", false);
+ assertStartsWith("äbćδe", "äbć", "UTF8_BINARY", true);
+ assertStartsWith("äbćδe", "äBc", "UTF8_BINARY", false);
+ assertStartsWith("ab世De", "ab世De", "UNICODE", true);
+ assertStartsWith("ab世De", "AB世dE", "UNICODE", false);
+ assertStartsWith("äbćδe", "äbćδe", "UNICODE", true);
+ assertStartsWith("äbćδe", "ÄBcΔÉ", "UNICODE", false);
+ assertStartsWith("ab世De", "ab世", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("ab世De", "aB世", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("äbćδe", "äbć", "UTF8_BINARY_LCASE", true);
+ assertStartsWith("äbćδe", "äBc", "UTF8_BINARY_LCASE", false);
+ assertStartsWith("ab世De", "ab世De", "UNICODE_CI", true);
+ assertStartsWith("ab世De", "AB世dE", "UNICODE_CI", true);
+ assertStartsWith("äbćδe", "ÄbćδE", "UNICODE_CI", true);
+ assertStartsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false);
+ }
+
+ private void assertEndsWith(String pattern, String suffix, String
collationName, boolean value)
+ throws SparkException {
+ UTF8String l = UTF8String.fromString(pattern);
+ UTF8String r = UTF8String.fromString(suffix);
+ int collationId = CollationFactory.collationNameToId(collationName);
+ assertEquals(CollationSupport.EndsWith.exec(l, r, collationId), value);
+ }
+
+ @Test
+ public void testEndsWith() throws SparkException {
+ // Edge cases
+ assertEndsWith("", "", "UTF8_BINARY", true);
+ assertEndsWith("c", "", "UTF8_BINARY", true);
+ assertEndsWith("", "c", "UTF8_BINARY", false);
+ assertEndsWith("", "", "UNICODE", true);
+ assertEndsWith("c", "", "UNICODE", true);
+ assertEndsWith("", "c", "UNICODE", false);
+ assertEndsWith("", "", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("c", "", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("", "c", "UTF8_BINARY_LCASE", false);
+ assertEndsWith("", "", "UNICODE_CI", true);
+ assertEndsWith("c", "", "UNICODE_CI", true);
+ assertEndsWith("", "c", "UNICODE_CI", false);
+ // Basic tests
+ assertEndsWith("abcde", "cde", "UTF8_BINARY", true);
+ assertEndsWith("abcde", "bde", "UTF8_BINARY", false);
+ assertEndsWith("abcde", "fgh", "UTF8_BINARY", false);
+ assertEndsWith("abcde", "abcde", "UNICODE", true);
+ assertEndsWith("abcde", "aBcDe", "UNICODE", false);
+ assertEndsWith("abcde", "fghij", "UNICODE", false);
+ assertEndsWith("abcde", "E", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("abcde", "AbCdE", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("abcde", "X", "UTF8_BINARY_LCASE", false);
+ assertEndsWith("abcde", "e", "UNICODE_CI", true);
+ assertEndsWith("abcde", "CDe", "UNICODE_CI", true);
+ assertEndsWith("abcde", "123", "UNICODE_CI", false);
+ // Case variation
+ assertEndsWith("aBcDe", "cde", "UTF8_BINARY", false);
+ assertEndsWith("aBcDe", "cDe", "UTF8_BINARY", true);
+ assertEndsWith("aBcDe", "abcde", "UNICODE", false);
+ assertEndsWith("aBcDe", "aBcDe", "UNICODE", true);
+ assertEndsWith("aBcDe", "cde", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("aBcDe", "CDE", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("aBcDe", "abcde", "UNICODE_CI", true);
+ assertEndsWith("aBcDe", "AbCdE", "UNICODE_CI", true);
+ // Accent variation
+ assertEndsWith("aBcDe", "ćde", "UTF8_BINARY", false);
+ assertEndsWith("aBcDe", "ćDe", "UTF8_BINARY", false);
+ assertEndsWith("aBcDe", "abćde", "UNICODE", false);
+ assertEndsWith("aBcDe", "aBćDe", "UNICODE", false);
+ assertEndsWith("aBcDe", "ćde", "UTF8_BINARY_LCASE", false);
+ assertEndsWith("aBcDe", "ĆDE", "UTF8_BINARY_LCASE", false);
+ assertEndsWith("aBcDe", "abćde", "UNICODE_CI", false);
+ assertEndsWith("aBcDe", "AbĆdE", "UNICODE_CI", false);
+ // Variable byte length characters
+ assertEndsWith("ab世De", "世De", "UTF8_BINARY", true);
+ assertEndsWith("ab世De", "世dE", "UTF8_BINARY", false);
+ assertEndsWith("äbćδe", "ćδe", "UTF8_BINARY", true);
+ assertEndsWith("äbćδe", "cΔé", "UTF8_BINARY", false);
+ assertEndsWith("ab世De", "ab世De", "UNICODE", true);
+ assertEndsWith("ab世De", "AB世dE", "UNICODE", false);
+ assertEndsWith("äbćδe", "äbćδe", "UNICODE", true);
+ assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE", false);
+ assertEndsWith("ab世De", "世De", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("ab世De", "世dE", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("äbćδe", "ćδe", "UTF8_BINARY_LCASE", true);
+ assertEndsWith("äbćδe", "cδE", "UTF8_BINARY_LCASE", false);
+ assertEndsWith("ab世De", "ab世De", "UNICODE_CI", true);
+ assertEndsWith("ab世De", "AB世dE", "UNICODE_CI", true);
+ assertEndsWith("äbćδe", "ÄbćδE", "UNICODE_CI", true);
+ assertEndsWith("äbćδe", "ÄBcΔÉ", "UNICODE_CI", false);
+ }
+
+ // TODO: Test more collation-aware string expressions.
+
+ /**
+ * Collation-aware regexp expressions.
+ */
+
+ // TODO: Test more collation-aware regexp expressions.
+
+ /**
+ * Other collation-aware expressions.
+ */
+
+ // TODO: Test other collation-aware expressions.
+
+}
diff --git
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringWithCollationSuite.java
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringWithCollationSuite.java
deleted file mode 100644
index b60da7b945a4..000000000000
---
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringWithCollationSuite.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.unsafe.types;
-
-import org.apache.spark.SparkException;
-import org.apache.spark.sql.catalyst.util.CollationFactory;
-import org.junit.jupiter.api.Test;
-
-import static org.junit.jupiter.api.Assertions.*;
-
-
-public class UTF8StringWithCollationSuite {
-
- private void assertStartsWith(String pattern, String prefix, String
collationName, boolean value)
- throws SparkException {
-
assertEquals(UTF8String.fromString(pattern).startsWith(UTF8String.fromString(prefix),
- CollationFactory.collationNameToId(collationName)), value);
- }
-
- private void assertEndsWith(String pattern, String suffix, String
collationName, boolean value)
- throws SparkException {
-
assertEquals(UTF8String.fromString(pattern).endsWith(UTF8String.fromString(suffix),
- CollationFactory.collationNameToId(collationName)), value);
- }
-
- @Test
- public void startsWithTest() throws SparkException {
- assertStartsWith("", "", "UTF8_BINARY", true);
- assertStartsWith("c", "", "UTF8_BINARY", true);
- assertStartsWith("", "c", "UTF8_BINARY", false);
- assertStartsWith("abcde", "a", "UTF8_BINARY", true);
- assertStartsWith("abcde", "A", "UTF8_BINARY", false);
- assertStartsWith("abcde", "bcd", "UTF8_BINARY", false);
- assertStartsWith("abcde", "BCD", "UTF8_BINARY", false);
- assertStartsWith("", "", "UNICODE", true);
- assertStartsWith("c", "", "UNICODE", true);
- assertStartsWith("", "c", "UNICODE", false);
- assertStartsWith("abcde", "a", "UNICODE", true);
- assertStartsWith("abcde", "A", "UNICODE", false);
- assertStartsWith("abcde", "bcd", "UNICODE", false);
- assertStartsWith("abcde", "BCD", "UNICODE", false);
- assertStartsWith("", "", "UTF8_BINARY_LCASE", true);
- assertStartsWith("c", "", "UTF8_BINARY_LCASE", true);
- assertStartsWith("", "c", "UTF8_BINARY_LCASE", false);
- assertStartsWith("abcde", "a", "UTF8_BINARY_LCASE", true);
- assertStartsWith("abcde", "A", "UTF8_BINARY_LCASE", true);
- assertStartsWith("abcde", "abc", "UTF8_BINARY_LCASE", true);
- assertStartsWith("abcde", "BCD", "UTF8_BINARY_LCASE", false);
- assertStartsWith("", "", "UNICODE_CI", true);
- assertStartsWith("c", "", "UNICODE_CI", true);
- assertStartsWith("", "c", "UNICODE_CI", false);
- assertStartsWith("abcde", "a", "UNICODE_CI", true);
- assertStartsWith("abcde", "A", "UNICODE_CI", true);
- assertStartsWith("abcde", "abc", "UNICODE_CI", true);
- assertStartsWith("abcde", "BCD", "UNICODE_CI", false);
- }
-
- @Test
- public void endsWithTest() throws SparkException {
- assertEndsWith("", "", "UTF8_BINARY", true);
- assertEndsWith("c", "", "UTF8_BINARY", true);
- assertEndsWith("", "c", "UTF8_BINARY", false);
- assertEndsWith("abcde", "e", "UTF8_BINARY", true);
- assertEndsWith("abcde", "E", "UTF8_BINARY", false);
- assertEndsWith("abcde", "bcd", "UTF8_BINARY", false);
- assertEndsWith("abcde", "BCD", "UTF8_BINARY", false);
- assertEndsWith("", "", "UNICODE", true);
- assertEndsWith("c", "", "UNICODE", true);
- assertEndsWith("", "c", "UNICODE", false);
- assertEndsWith("abcde", "e", "UNICODE", true);
- assertEndsWith("abcde", "E", "UNICODE", false);
- assertEndsWith("abcde", "bcd", "UNICODE", false);
- assertEndsWith("abcde", "BCD", "UNICODE", false);
- assertEndsWith("", "", "UTF8_BINARY_LCASE", true);
- assertEndsWith("c", "", "UTF8_BINARY_LCASE", true);
- assertEndsWith("", "c", "UTF8_BINARY_LCASE", false);
- assertEndsWith("abcde", "e", "UTF8_BINARY_LCASE", true);
- assertEndsWith("abcde", "E", "UTF8_BINARY_LCASE", true);
- assertEndsWith("abcde", "cde", "UTF8_BINARY_LCASE", true);
- assertEndsWith("abcde", "BCD", "UTF8_BINARY_LCASE", false);
- assertEndsWith("", "", "UNICODE_CI", true);
- assertEndsWith("c", "", "UNICODE_CI", true);
- assertEndsWith("", "c", "UNICODE_CI", false);
- assertEndsWith("abcde", "e", "UNICODE_CI", true);
- assertEndsWith("abcde", "E", "UNICODE_CI", true);
- assertEndsWith("abcde", "cde", "UNICODE_CI", true);
- assertEndsWith("abcde", "BCD", "UNICODE_CI", false);
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 01f22720dd12..5aa766a60c10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.encoders.HashableWeakReference
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory,
MapData, SQLOrderingUtil, UnsafeRowUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory,
CollationSupport, MapData, SQLOrderingUtil, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -1531,6 +1531,7 @@ object CodeGenerator extends Logging {
classOf[TaskKilledException].getName,
classOf[InputMetrics].getName,
classOf[CollationFactory].getName,
+ classOf[CollationSupport].getName,
QueryExecutionErrors.getClass.getName.stripSuffix("$")
)
evaluator.setExtendedClass(classOf[GeneratedClass])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index cf6c9d4f1d94..9c862581bfe4 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern,
UPPER_OR_LOWER}
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory,
GenericArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationSupport,
GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
@@ -591,18 +591,11 @@ object ContainsExpressionBuilder extends StringBinaryPredicateExpressionBuilderB
case class Contains(left: Expression, right: Expression) extends
StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- l.contains(r)
- } else {
- l.contains(r, collationId)
- }
+ CollationSupport.Contains.exec(l, r, collationId)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2)")
- } else {
- defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2, $collationId)")
- }
+ defineCodeGen(ctx, ev, (c1, c2) =>
+ CollationSupport.Contains.genCode(c1, c2, collationId))
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Contains = copy(left =
newLeft, right = newRight)
@@ -638,19 +631,12 @@ object StartsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilde
case class StartsWith(left: Expression, right: Expression) extends
StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- l.startsWith(r)
- } else {
- l.startsWith(r, collationId)
- }
+ CollationSupport.StartsWith.exec(l, r, collationId)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2)")
- } else {
- defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2, $collationId)")
- }
+ defineCodeGen(ctx, ev, (c1, c2) =>
+ CollationSupport.StartsWith.genCode(c1, c2, collationId))
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): StartsWith = copy(left =
newLeft, right = newRight)
@@ -686,19 +672,12 @@ object EndsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilderB
case class EndsWith(left: Expression, right: Expression) extends
StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- l.endsWith(r)
- } else {
- l.endsWith(r, collationId)
- }
+ CollationSupport.EndsWith.exec(l, r, collationId)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
- defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2)")
- } else {
- defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2, $collationId)")
- }
+ defineCodeGen(ctx, ev, (c1, c2) =>
+ CollationSupport.EndsWith.genCode(c1, c2, collationId))
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): EndsWith = copy(left =
newLeft, right = newRight)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
index c547068a03c3..0876425847bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationRegexpExpressionsSuite.scala
@@ -20,420 +20,310 @@ package org.apache.spark.sql
import scala.collection.immutable.Seq
import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{ArrayType, BooleanType, IntegerType,
StringType}
class CollationRegexpExpressionsSuite
extends QueryTest
with SharedSparkSession
with ExpressionEvalHelper {
- case class CollationTestCase[R](s1: String, s2: String, collation: String,
expectedResult: R)
- case class CollationTestFail[R](s1: String, s2: String, collation: String)
-
- test("Support Like string expression with Collation") {
- def prepareLike(
- input: String,
- regExp: String,
- collation: String): Expression = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- Like(inputExpr, regExpExpr, '\\')
- }
+ test("Support Like string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABC", "%B%", "UTF8_BINARY", true)
- )
- checks.foreach(ct =>
- checkEvaluation(prepareLike(ct.s1, ct.s2, ct.collation),
ct.expectedResult))
+ case class LikeTestCase[R](l: String, r: String, c: String, result: R)
+ val testCases = Seq(
+ LikeTestCase("ABC", "%B%", "UTF8_BINARY", true)
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT like(collate('${t.l}', '${t.c}'), collate('${t.r}',
'${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(BooleanType))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABC", "%B%", "UNICODE"),
- CollationTestFail("ABC", "%b%", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareLike(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class LikeTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ LikeTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"),
+ LikeTestFail("ABC", "%B%", "UNICODE"),
+ LikeTestFail("ABC", "%b%", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query = s"SELECT like(collate('${t.l}', '${t.c}'), collate('${t.r}',
'${t.c}'))"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support ILike string expression with Collation") {
- def prepareILike(
- input: String,
- regExp: String,
- collation: String): Expression = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- ILike(inputExpr, regExpExpr, '\\').replacement
- }
-
+ test("Support ILike string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABC", "%b%", "UTF8_BINARY", true)
- )
- checks.foreach(ct =>
- checkEvaluation(prepareILike(ct.s1, ct.s2, ct.collation),
ct.expectedResult)
- )
+ case class ILikeTestCase[R](l: String, r: String, c: String, result: R)
+ val testCases = Seq(
+ ILikeTestCase("ABC", "%b%", "UTF8_BINARY", true)
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT ilike(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(BooleanType))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABC", "%b%", "UNICODE"),
- CollationTestFail("ABC", "%b%", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareILike(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""lower('${ct.s1}' collate ${ct.collation})"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class ILikeTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ ILikeTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"),
+ ILikeTestFail("ABC", "%b%", "UNICODE"),
+ ILikeTestFail("ABC", "%b%", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query = s"SELECT ilike(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support RLike string expression with Collation") {
- def prepareRLike(
- input: String,
- regExp: String,
- collation: String): Expression = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- RLike(inputExpr, regExpExpr)
- }
+ test("Support RLike string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABC", ".B.", "UTF8_BINARY", true)
- )
- checks.foreach(ct =>
- checkEvaluation(prepareRLike(ct.s1, ct.s2, ct.collation),
ct.expectedResult)
- )
+ case class RLikeTestCase[R](l: String, r: String, c: String, result: R)
+ val testCases = Seq(
+ RLikeTestCase("ABC", ".B.", "UTF8_BINARY", true)
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT rlike(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(BooleanType))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABC", ".b.", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABC", ".B.", "UNICODE"),
- CollationTestFail("ABC", ".b.", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareRLike(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class RLikeTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ RLikeTestFail("ABC", ".b.", "UTF8_BINARY_LCASE"),
+ RLikeTestFail("ABC", ".B.", "UNICODE"),
+ RLikeTestFail("ABC", ".b.", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query = s"SELECT rlike(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support StringSplit string expression with Collation") {
- def prepareStringSplit(
- input: String,
- splitBy: String,
- collation: String): Expression = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val splitByExpr = Literal.create(splitBy, StringType(collationId))
- StringSplit(inputExpr, splitByExpr, Literal(-1))
- }
-
+ test("Support StringSplit string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C"))
- )
- checks.foreach(ct =>
- checkEvaluation(prepareStringSplit(ct.s1, ct.s2, ct.collation),
ct.expectedResult)
- )
+ case class StringSplitTestCase[R](l: String, r: String, c: String, result:
R)
+ val testCases = Seq(
+ StringSplitTestCase("ABC", "[B]", "UTF8_BINARY", Seq("A", "C"))
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT split(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+
assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c))))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABC", "[B]", "UNICODE"),
- CollationTestFail("ABC", "[b]", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareStringSplit(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class StringSplitTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ StringSplitTestFail("ABC", "[b]", "UTF8_BINARY_LCASE"),
+ StringSplitTestFail("ABC", "[B]", "UNICODE"),
+ StringSplitTestFail("ABC", "[b]", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query = s"SELECT split(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support RegExpReplace string expression with Collation") {
- def prepareRegExpReplace(
- input: String,
- regExp: String,
- collation: String): RegExpReplace = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- RegExpReplace(inputExpr, regExpExpr, Literal.create("FFF",
StringType(collationId)))
- }
-
+ test("Support RegExpReplace string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE")
- )
- checks.foreach(ct =>
- checkEvaluation(prepareRegExpReplace(ct.s1, ct.s2, ct.collation),
ct.expectedResult)
- )
+ case class RegExpReplaceTestCase[R](l: String, r: String, c: String,
result: R)
+ val testCases = Seq(
+ RegExpReplaceTestCase("ABCDE", ".C.", "UTF8_BINARY", "AFFFE")
+ )
+ testCases.foreach(t => {
+ val query =
+ s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), collate('${t.r}',
'${t.c}'), 'FFF')"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABCDE", ".C.", "UNICODE"),
- CollationTestFail("ABCDE", ".c.", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareRegExpReplace(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class RegExpReplaceTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ RegExpReplaceTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
+ RegExpReplaceTestFail("ABCDE", ".C.", "UNICODE"),
+ RegExpReplaceTestFail("ABCDE", ".c.", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query =
+ s"SELECT regexp_replace(collate('${t.l}', '${t.c}'), collate('${t.r}',
'${t.c}'), 'FFF')"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support RegExpExtract string expression with Collation") {
- def prepareRegExpExtract(
- input: String,
- regExp: String,
- collation: String): RegExpExtract = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- RegExpExtract(inputExpr, regExpExpr, Literal(0))
- }
-
+ test("Support RegExpExtract string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD")
- )
- checks.foreach(ct =>
- checkEvaluation(prepareRegExpExtract(ct.s1, ct.s2, ct.collation),
ct.expectedResult)
- )
+ case class RegExpExtractTestCase[R](l: String, r: String, c: String,
result: R)
+ val testCases = Seq(
+ RegExpExtractTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD")
+ )
+ testCases.foreach(t => {
+ val query =
+ s"SELECT regexp_extract(collate('${t.l}', '${t.c}'), collate('${t.r}',
'${t.c}'), 0)"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABCDE", ".C.", "UNICODE"),
- CollationTestFail("ABCDE", ".c.", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareRegExpExtract(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class RegExpExtractTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ RegExpExtractTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
+ RegExpExtractTestFail("ABCDE", ".C.", "UNICODE"),
+ RegExpExtractTestFail("ABCDE", ".c.", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query =
+ s"SELECT regexp_extract(collate('${t.l}', '${t.c}'), collate('${t.r}',
'${t.c}'), 0)"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support RegExpExtractAll string expression with Collation") {
- def prepareRegExpExtractAll(
- input: String,
- regExp: String,
- collation: String): RegExpExtractAll = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- RegExpExtractAll(inputExpr, regExpExpr, Literal(0))
- }
-
+ test("Support RegExpExtractAll string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD"))
- )
- checks.foreach(ct =>
- checkEvaluation(prepareRegExpExtractAll(ct.s1, ct.s2, ct.collation),
ct.expectedResult)
- )
+ case class RegExpExtractAllTestCase[R](l: String, r: String, c: String,
result: R)
+ val testCases = Seq(
+ RegExpExtractAllTestCase("ABCDE", ".C.", "UTF8_BINARY", Seq("BCD"))
+ )
+ testCases.foreach(t => {
+ val query =
+ s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'), 0)"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+
assert(sql(query).schema.fields.head.dataType.sameType(ArrayType(StringType(t.c))))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABCDE", ".C.", "UNICODE"),
- CollationTestFail("ABCDE", ".c.", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareRegExpExtractAll(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class RegExpExtractAllTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ RegExpExtractAllTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
+ RegExpExtractAllTestFail("ABCDE", ".C.", "UNICODE"),
+ RegExpExtractAllTestFail("ABCDE", ".c.", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query =
+ s"SELECT regexp_extract_all(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'), 0)"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support RegExpCount string expression with Collation") {
- def prepareRegExpCount(
- input: String,
- regExp: String,
- collation: String): Expression = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- RegExpCount(inputExpr, regExpExpr).replacement
- }
-
+ test("Support RegExpCount string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 1)
- )
- checks.foreach(ct =>
- checkEvaluation(prepareRegExpCount(ct.s1, ct.s2, ct.collation),
ct.expectedResult)
- )
+ case class RegExpCountTestCase[R](l: String, r: String, c: String, result:
R)
+ val testCases = Seq(
+ RegExpCountTestCase("ABCDE", ".C.", "UTF8_BINARY", 1)
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABCDE", ".C.", "UNICODE"),
- CollationTestFail("ABCDE", ".c.", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareRegExpCount(ct.s1, ct.s2,
ct.collation).asInstanceOf[Size].child
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class RegExpCountTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ RegExpCountTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
+ RegExpCountTestFail("ABCDE", ".C.", "UNICODE"),
+ RegExpCountTestFail("ABCDE", ".c.", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query = s"SELECT regexp_count(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support RegExpSubStr string expression with Collation") {
- def prepareRegExpSubStr(
- input: String,
- regExp: String,
- collation: String): Expression = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- RegExpSubStr(inputExpr, regExpExpr).replacement.asInstanceOf[NullIf].left
- }
-
+ test("Support RegExpSubStr string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD")
- )
- checks.foreach(ct =>
- checkEvaluation(prepareRegExpSubStr(ct.s1, ct.s2, ct.collation),
ct.expectedResult)
- )
+ case class RegExpSubStrTestCase[R](l: String, r: String, c: String,
result: R)
+ val testCases = Seq(
+ RegExpSubStrTestCase("ABCDE", ".C.", "UTF8_BINARY", "BCD")
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'),
collate('${t.r}', '${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABCDE", ".C.", "UNICODE"),
- CollationTestFail("ABCDE", ".c.", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareRegExpSubStr(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class RegExpSubStrTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ RegExpSubStrTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
+ RegExpSubStrTestFail("ABCDE", ".C.", "UNICODE"),
+ RegExpSubStrTestFail("ABCDE", ".c.", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query = s"SELECT regexp_substr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
- test("Support RegExpInStr string expression with Collation") {
- def prepareRegExpInStr(
- input: String,
- regExp: String,
- collation: String): RegExpInStr = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExpr = Literal.create(input, StringType(collationId))
- val regExpExpr = Literal.create(regExp, StringType(collationId))
- RegExpInStr(inputExpr, regExpExpr, Literal(0))
- }
-
+ test("Support RegExpInStr string expression with collation") {
// Supported collations
- val checks = Seq(
- CollationTestCase("ABCDE", ".C.", "UTF8_BINARY", 2)
- )
- checks.foreach(ct =>
- checkEvaluation(prepareRegExpInStr(ct.s1, ct.s2, ct.collation), ct.expectedResult)
- )
+ case class RegExpInStrTestCase[R](l: String, r: String, c: String, result: R)
+ val testCases = Seq(
+ RegExpInStrTestCase("ABCDE", ".C.", "UTF8_BINARY", 2)
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT regexp_instr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
+ // TODO: Implicit casting (not currently supported)
+ })
// Unsupported collations
- val fails = Seq(
- CollationTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABCDE", ".C.", "UNICODE"),
- CollationTestFail("ABCDE", ".c.", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareRegExpInStr(ct.s1, ct.s2, ct.collation)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""'${ct.s1}' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
- )
+ case class RegExpInStrTestFail(l: String, r: String, c: String)
+ val failCases = Seq(
+ RegExpInStrTestFail("ABCDE", ".c.", "UTF8_BINARY_LCASE"),
+ RegExpInStrTestFail("ABCDE", ".C.", "UNICODE"),
+ RegExpInStrTestFail("ABCDE", ".c.", "UNICODE_CI")
+ )
+ failCases.foreach(t => {
+ val query = s"SELECT regexp_instr(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
+ val unsupportedCollation = intercept[AnalysisException] {
+ sql(query)
+ }
+ assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // TODO: Collation mismatch (not currently supported)
}
+
}
class CollationRegexpExpressionsANSISuite extends CollationRegexpExpressionsSuite {
override protected def sparkConf: SparkConf =
super.sparkConf.set(SQLConf.ANSI_ENABLED, true)
+
+ // TODO: If needed, add more tests for other regexp expressions (with ANSI mode enabled)
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index c26f3ae02255..97dea6697541 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -20,80 +20,157 @@ package org.apache.spark.sql
import scala.collection.immutable.Seq
import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
-import org.apache.spark.sql.catalyst.expressions.{Collation, ConcatWs, ExpressionEvalHelper, Literal, StringRepeat}
-import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{BooleanType, StringType}
class CollationStringExpressionsSuite
extends QueryTest
with SharedSparkSession
with ExpressionEvalHelper {
- case class CollationTestCase[R](s1: String, s2: String, collation: String, expectedResult: R)
- case class CollationTestFail[R](s1: String, s2: String, collation: String)
-
-
- test("Support ConcatWs string expression with Collation") {
- def prepareConcatWs(
- sep: String,
- collation: String,
- inputs: Any*): ConcatWs = {
- val collationId = CollationFactory.collationNameToId(collation)
- val inputExprs = inputs.map(s => Literal.create(s, StringType(collationId)))
- val sepExpr = Literal.create(sep, StringType(collationId))
- ConcatWs(sepExpr +: inputExprs)
- }
- // Supported Collations
- val checks = Seq(
- CollationTestCase("Spark", "SQL", "UTF8_BINARY", "Spark SQL")
+ test("Support ConcatWs string expression with collation") {
+ // Supported collations
+ case class ConcatWsTestCase[R](s: String, a: Array[String], c: String, result: R)
+ val testCases = Seq(
+ ConcatWsTestCase(" ", Array("Spark", "SQL"), "UTF8_BINARY", "Spark SQL")
)
- checks.foreach(ct =>
- checkEvaluation(prepareConcatWs(" ", ct.collation, ct.s1, ct.s2), ct.expectedResult)
+ testCases.foreach(t => {
+ val arrCollated = t.a.map(s => s"collate('$s', '${t.c}')").mkString(", ")
+ var query = s"SELECT concat_ws(collate('${t.s}', '${t.c}'), $arrCollated)"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+ // Implicit casting
+ val arr = t.a.map(s => s"'$s'").mkString(", ")
+ query = s"SELECT concat_ws(collate('${t.s}', '${t.c}'), $arr)"
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+ query = s"SELECT concat_ws('${t.s}', $arrCollated)"
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+ })
+ // Unsupported collations
+ case class ConcatWsTestFail(s: String, a: Array[String], c: String)
+ val failCases = Seq(
+ ConcatWsTestFail(" ", Array("ABC", "%b%"), "UTF8_BINARY_LCASE"),
+ ConcatWsTestFail(" ", Array("ABC", "%B%"), "UNICODE"),
+ ConcatWsTestFail(" ", Array("ABC", "%b%"), "UNICODE_CI")
)
+ failCases.foreach(t => {
+ val arrCollated = t.a.map(s => s"collate('$s', '${t.c}')").mkString(", ")
+ val query = s"SELECT concat_ws(collate('${t.s}', '${t.c}'), $arrCollated)"
+ val unsupportedCollation = intercept[AnalysisException] { sql(query) }
+ assert(unsupportedCollation.getErrorClass === "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE")
+ })
+ // Collation mismatch
+ val collationMismatch = intercept[AnalysisException] {
+ sql("SELECT concat_ws(' ',collate('Spark', 'UTF8_BINARY_LCASE'),collate('SQL', 'UNICODE'))")
+ }
+ assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+ }
- // Unsupported Collations
- val fails = Seq(
- CollationTestFail("ABC", "%b%", "UTF8_BINARY_LCASE"),
- CollationTestFail("ABC", "%B%", "UNICODE"),
- CollationTestFail("ABC", "%b%", "UNICODE_CI")
- )
- fails.foreach(ct =>
- assert(prepareConcatWs(" ", ct.collation, ct.s1, ct.s2)
- .checkInputDataTypes() ==
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> "first",
- "requiredType" -> """"STRING"""",
- "inputSql" -> s""""' ' collate ${ct.collation}"""",
- "inputType" -> s""""STRING COLLATE ${ct.collation}""""
- )
- )
- )
+ test("Support Contains string expression with collation") {
+ // Supported collations
+ case class ContainsTestCase[R](l: String, r: String, c: String, result: R)
+ val testCases = Seq(
+ ContainsTestCase("", "", "UTF8_BINARY", true),
+ ContainsTestCase("abcde", "C", "UNICODE", false),
+ ContainsTestCase("abcde", "FGH", "UTF8_BINARY_LCASE", false),
+ ContainsTestCase("abcde", "BCD", "UNICODE_CI", true)
)
+ testCases.foreach(t => {
+ val query = s"SELECT contains(collate('${t.l}','${t.c}'),collate('${t.r}','${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(BooleanType))
+ // Implicit casting
+ checkAnswer(sql(s"SELECT contains(collate('${t.l}','${t.c}'),'${t.r}')"), Row(t.result))
+ checkAnswer(sql(s"SELECT contains('${t.l}',collate('${t.r}','${t.c}'))"), Row(t.result))
+ })
+ // Collation mismatch
+ val collationMismatch = intercept[AnalysisException] {
+ sql("SELECT contains(collate('abcde','UTF8_BINARY_LCASE'),collate('C','UNICODE_CI'))")
+ }
+ assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}
- test("REPEAT check output type on explicitly collated string") {
- def testRepeat(expected: String, collationId: Int, input: String, n: Int): Unit = {
- val s = Literal.create(input, StringType(collationId))
+ test("Support StartsWith string expression with collation") {
+ // Supported collations
+ case class StartsWithTestCase[R](l: String, r: String, c: String, result: R)
+ val testCases = Seq(
+ StartsWithTestCase("", "", "UTF8_BINARY", true),
+ StartsWithTestCase("abcde", "A", "UNICODE", false),
+ StartsWithTestCase("abcde", "FGH", "UTF8_BINARY_LCASE", false),
+ StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true)
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT startswith(collate('${t.l}','${t.c}'),collate('${t.r}','${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(BooleanType))
+ // Implicit casting
+ checkAnswer(sql(s"SELECT startswith(collate('${t.l}', '${t.c}'),'${t.r}')"), Row(t.result))
+ checkAnswer(sql(s"SELECT startswith('${t.l}', collate('${t.r}', '${t.c}'))"), Row(t.result))
+ })
+ // Collation mismatch
+ val collationMismatch = intercept[AnalysisException] {
+ sql("SELECT startswith(collate('abcde', 'UTF8_BINARY_LCASE'),collate('C', 'UNICODE_CI'))")
+ }
+ assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+ }
- checkEvaluation(Collation(StringRepeat(s, Literal.create(n))).replacement, expected)
+ test("Support EndsWith string expression with collation") {
+ // Supported collations
+ case class EndsWithTestCase[R](l: String, r: String, c: String, result: R)
+ val testCases = Seq(
+ EndsWithTestCase("", "", "UTF8_BINARY", true),
+ EndsWithTestCase("abcde", "E", "UNICODE", false),
+ EndsWithTestCase("abcde", "FGH", "UTF8_BINARY_LCASE", false),
+ EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true)
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT endswith(collate('${t.l}', '${t.c}'), collate('${t.r}', '${t.c}'))"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(BooleanType))
+ // Implicit casting
+ checkAnswer(sql(s"SELECT endswith(collate('${t.l}', '${t.c}'),'${t.r}')"), Row(t.result))
+ checkAnswer(sql(s"SELECT endswith('${t.l}', collate('${t.r}', '${t.c}'))"), Row(t.result))
+ })
+ // Collation mismatch
+ val collationMismatch = intercept[AnalysisException] {
+ sql("SELECT endswith(collate('abcde', 'UTF8_BINARY_LCASE'),collate('C', 'UNICODE_CI'))")
}
+ assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
+ }
- testRepeat("UTF8_BINARY", 0, "abc", 2)
- testRepeat("UTF8_BINARY_LCASE", 1, "abc", 2)
- testRepeat("UNICODE", 2, "abc", 2)
- testRepeat("UNICODE_CI", 3, "abc", 2)
+ test("Support StringRepeat string expression with collation") {
+ // Supported collations
+ case class StringRepeatTestCase[R](s: String, n: Int, c: String, result: R)
+ val testCases = Seq(
+ StringRepeatTestCase("", 1, "UTF8_BINARY", ""),
+ StringRepeatTestCase("a", 0, "UNICODE", ""),
+ StringRepeatTestCase("XY", 3, "UTF8_BINARY_LCASE", "XYXYXY"),
+ StringRepeatTestCase("123", 2, "UNICODE_CI", "123123")
+ )
+ testCases.foreach(t => {
+ val query = s"SELECT repeat(collate('${t.s}', '${t.c}'), ${t.n})"
+ // Result & data type
+ checkAnswer(sql(query), Row(t.result))
+ assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+ })
}
// TODO: Add more tests for other string expressions
}
-class CollationStringExpressionsANSISuite extends CollationRegexpExpressionsSuite {
+class CollationStringExpressionsANSISuite extends CollationStringExpressionsSuite {
override protected def sparkConf: SparkConf =
super.sparkConf.set(SQLConf.ANSI_ENABLED, true)
+
+ // TODO: If needed, add more tests for other string expressions (with ANSI mode enabled)
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index c0322387c804..c4ddd25c99b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -271,90 +271,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
)
}
- case class CollationTestCase[R](left: String, right: String, collation: String, expectedResult: R)
-
- test("Support contains string expression with Collation") {
- // Supported collations
- val checks = Seq(
- CollationTestCase("", "", "UTF8_BINARY", true),
- CollationTestCase("c", "", "UTF8_BINARY", true),
- CollationTestCase("", "c", "UTF8_BINARY", false),
- CollationTestCase("abcde", "c", "UTF8_BINARY", true),
- CollationTestCase("abcde", "C", "UTF8_BINARY", false),
- CollationTestCase("abcde", "bcd", "UTF8_BINARY", true),
- CollationTestCase("abcde", "BCD", "UTF8_BINARY", false),
- CollationTestCase("abcde", "fgh", "UTF8_BINARY", false),
- CollationTestCase("abcde", "FGH", "UTF8_BINARY", false),
- CollationTestCase("", "", "UNICODE", true),
- CollationTestCase("c", "", "UNICODE", true),
- CollationTestCase("", "c", "UNICODE", false),
- CollationTestCase("abcde", "c", "UNICODE", true),
- CollationTestCase("abcde", "C", "UNICODE", false),
- CollationTestCase("abcde", "bcd", "UNICODE", true),
- CollationTestCase("abcde", "BCD", "UNICODE", false),
- CollationTestCase("abcde", "fgh", "UNICODE", false),
- CollationTestCase("abcde", "FGH", "UNICODE", false),
- CollationTestCase("", "", "UTF8_BINARY_LCASE", true),
- CollationTestCase("c", "", "UTF8_BINARY_LCASE", true),
- CollationTestCase("", "c", "UTF8_BINARY_LCASE", false),
- CollationTestCase("abcde", "c", "UTF8_BINARY_LCASE", true),
- CollationTestCase("abcde", "C", "UTF8_BINARY_LCASE", true),
- CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", true),
- CollationTestCase("abcde", "BCD", "UTF8_BINARY_LCASE", true),
- CollationTestCase("abcde", "fgh", "UTF8_BINARY_LCASE", false),
- CollationTestCase("abcde", "FGH", "UTF8_BINARY_LCASE", false),
- CollationTestCase("", "", "UNICODE_CI", true),
- CollationTestCase("c", "", "UNICODE_CI", true),
- CollationTestCase("", "c", "UNICODE_CI", false),
- CollationTestCase("abcde", "c", "UNICODE_CI", true),
- CollationTestCase("abcde", "C", "UNICODE_CI", true),
- CollationTestCase("abcde", "bcd", "UNICODE_CI", true),
- CollationTestCase("abcde", "BCD", "UNICODE_CI", true),
- CollationTestCase("abcde", "fgh", "UNICODE_CI", false),
- CollationTestCase("abcde", "FGH", "UNICODE_CI", false)
- )
- checks.foreach(testCase => {
- checkAnswer(sql(s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," +
- s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
- })
- }
-
- test("Support startsWith string expression with Collation") {
- // Supported collations
- val checks = Seq(
- CollationTestCase("abcde", "abc", "UTF8_BINARY", true),
- CollationTestCase("abcde", "ABC", "UTF8_BINARY", false),
- CollationTestCase("abcde", "abc", "UNICODE", true),
- CollationTestCase("abcde", "ABC", "UNICODE", false),
- CollationTestCase("abcde", "ABC", "UTF8_BINARY_LCASE", true),
- CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false),
- CollationTestCase("abcde", "ABC", "UNICODE_CI", true),
- CollationTestCase("abcde", "bcd", "UNICODE_CI", false)
- )
- checks.foreach(testCase => {
- checkAnswer(sql(s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," +
- s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
- })
- }
-
- test("Support endsWith string expression with Collation") {
- // Supported collations
- val checks = Seq(
- CollationTestCase("abcde", "cde", "UTF8_BINARY", true),
- CollationTestCase("abcde", "CDE", "UTF8_BINARY", false),
- CollationTestCase("abcde", "cde", "UNICODE", true),
- CollationTestCase("abcde", "CDE", "UNICODE", false),
- CollationTestCase("abcde", "CDE", "UTF8_BINARY_LCASE", true),
- CollationTestCase("abcde", "bcd", "UTF8_BINARY_LCASE", false),
- CollationTestCase("abcde", "CDE", "UNICODE_CI", true),
- CollationTestCase("abcde", "bcd", "UNICODE_CI", false)
- )
- checks.foreach(testCase => {
- checkAnswer(sql(s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," +
- s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult))
- })
- }
-
test("aggregates count respects collation") {
Seq(
("utf8_binary", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]