This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d8aff6e04941 [SPARK-48937][SQL] Add collation support for StringToMap
string expressions
d8aff6e04941 is described below
commit d8aff6e0494198c48962fbe6e6b93edff2855067
Author: Uros Bojanic <[email protected]>
AuthorDate: Fri Aug 9 22:00:10 2024 +0800
[SPARK-48937][SQL] Add collation support for StringToMap string expressions
### What changes were proposed in this pull request?
Add collation awareness to the `StringToMap` string expression.
### Why are the changes needed?
`StringToMap` should be collation-aware when splitting strings on the specified
delimiters.
### Does this PR introduce _any_ user-facing change?
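Yes, `StringToMap` is now collation-aware: the pair and key-value delimiters are
matched using the collation of the input string (for example, under `UTF8_LCASE`
the delimiter `X` also splits on `x`).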
### How was this patch tested?
New unit tests and end-to-end SQL tests for `str_to_map`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47621 from uros-db/fix-str-to-map.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/util/CollationAwareUTF8String.java | 57 ++++++++++++++++++++++
.../spark/sql/catalyst/util/CollationSupport.java | 30 ++----------
.../expressions/codegen/CodeGenerator.scala | 3 +-
.../catalyst/expressions/complexTypeCreator.scala | 13 +++--
.../spark/sql/CollationSQLExpressionsSuite.scala | 40 +++++++++------
5 files changed, 96 insertions(+), 47 deletions(-)
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
index 501d173fc485..b57f172428ac 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -32,10 +32,13 @@ import static
org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType;
import java.text.CharacterIterator;
import java.text.StringCharacterIterator;
+import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
import java.util.Map;
+import java.util.regex.Pattern;
/**
* Utility class for collation-aware UTF8String operations.
@@ -1226,6 +1229,60 @@ public class CollationAwareUTF8String {
return UTF8String.fromString(src.substring(0, charIndex));
}
+ public static UTF8String[] splitSQL(final UTF8String input, final UTF8String
delim,
+ final int limit, final int collationId) {
+ if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
+ return input.split(delim, limit);
+ } else if
(CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) {
+ return lowercaseSplitSQL(input, delim, limit);
+ } else {
+ return icuSplitSQL(input, delim, limit, collationId);
+ }
+ }
+
+ public static UTF8String[] lowercaseSplitSQL(final UTF8String string, final
UTF8String delimiter,
+ final int limit) {
+ if (delimiter.numBytes() == 0) return new UTF8String[] { string };
+ if (string.numBytes() == 0) return new UTF8String[] {
UTF8String.EMPTY_UTF8 };
+ Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()),
+ CollationSupport.lowercaseRegexFlags);
+ String[] splits = pattern.split(string.toString(), limit);
+ UTF8String[] res = new UTF8String[splits.length];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = UTF8String.fromString(splits[i]);
+ }
+ return res;
+ }
+
+ public static UTF8String[] icuSplitSQL(final UTF8String string, final
UTF8String delimiter,
+ final int limit, final int collationId) {
+ if (delimiter.numBytes() == 0) return new UTF8String[] { string };
+ if (string.numBytes() == 0) return new UTF8String[] {
UTF8String.EMPTY_UTF8 };
+ List<UTF8String> strings = new ArrayList<>();
+ String target = string.toString(), pattern = delimiter.toString();
+ StringSearch stringSearch = CollationFactory.getStringSearch(target,
pattern, collationId);
+ int start = 0, end;
+ while ((end = stringSearch.next()) != StringSearch.DONE) {
+ if (limit > 0 && strings.size() == limit - 1) {
+ break;
+ }
+ strings.add(UTF8String.fromString(target.substring(start, end)));
+ start = end + stringSearch.getMatchLength();
+ }
+ if (start <= target.length()) {
+ strings.add(UTF8String.fromString(target.substring(start)));
+ }
+ if (limit == 0) {
+ // Remove trailing empty strings
+ int i = strings.size() - 1;
+ while (i >= 0 && strings.get(i).numBytes() == 0) {
+ strings.remove(i);
+ i--;
+ }
+ }
+ return strings.toArray(new UTF8String[0]);
+ }
+
// TODO: Add more collation-aware UTF8String operations here.
}
diff --git
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index f160661af389..651683796877 100644
---
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -20,8 +20,6 @@ import com.ibm.icu.text.StringSearch;
import org.apache.spark.unsafe.types.UTF8String;
-import java.util.ArrayList;
-import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
@@ -62,33 +60,11 @@ public final class CollationSupport {
return string.splitSQL(delimiter, -1);
}
public static UTF8String[] execLowercase(final UTF8String string, final
UTF8String delimiter) {
- if (delimiter.numBytes() == 0) return new UTF8String[] { string };
- if (string.numBytes() == 0) return new UTF8String[] {
UTF8String.EMPTY_UTF8 };
- Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()),
- CollationSupport.lowercaseRegexFlags);
- String[] splits = pattern.split(string.toString(), -1);
- UTF8String[] res = new UTF8String[splits.length];
- for (int i = 0; i < res.length; i++) {
- res[i] = UTF8String.fromString(splits[i]);
- }
- return res;
+ return CollationAwareUTF8String.lowercaseSplitSQL(string, delimiter, -1);
}
public static UTF8String[] execICU(final UTF8String string, final
UTF8String delimiter,
final int collationId) {
- if (delimiter.numBytes() == 0) return new UTF8String[] { string };
- if (string.numBytes() == 0) return new UTF8String[] {
UTF8String.EMPTY_UTF8 };
- List<UTF8String> strings = new ArrayList<>();
- String target = string.toString(), pattern = delimiter.toString();
- StringSearch stringSearch = CollationFactory.getStringSearch(target,
pattern, collationId);
- int start = 0, end;
- while ((end = stringSearch.next()) != StringSearch.DONE) {
- strings.add(UTF8String.fromString(target.substring(start, end)));
- start = end + stringSearch.getMatchLength();
- }
- if (start <= target.length()) {
- strings.add(UTF8String.fromString(target.substring(start)));
- }
- return strings.toArray(new UTF8String[0]);
+ return CollationAwareUTF8String.icuSplitSQL(string, delimiter, -1,
collationId);
}
}
@@ -696,7 +672,7 @@ public final class CollationSupport {
return
CollationFactory.fetchCollation(collationId).supportsLowercaseEquality;
}
- private static final int lowercaseRegexFlags = Pattern.UNICODE_CASE |
Pattern.CASE_INSENSITIVE;
+ static final int lowercaseRegexFlags = Pattern.UNICODE_CASE |
Pattern.CASE_INSENSITIVE;
public static int collationAwareRegexFlags(final int collationId) {
return supportsLowercaseRegex(collationId) ? lowercaseRegexFlags : 0;
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a39c10866984..30c00f5bf96b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -40,7 +40,7 @@ import
org.apache.spark.sql.catalyst.encoders.HashableWeakReference
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory,
CollationSupport, MapData, SQLOrderingUtil, UnsafeRowUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData,
CollationAwareUTF8String, CollationFactory, CollationSupport, MapData,
SQLOrderingUtil, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -1529,6 +1529,7 @@ object CodeGenerator extends Logging {
classOf[TaskContext].getName,
classOf[TaskKilledException].getName,
classOf[InputMetrics].getName,
+ classOf[CollationAwareUTF8String].getName,
classOf[CollationFactory].getName,
classOf[CollationSupport].getName,
QueryExecutionErrors.getClass.getName.stripSuffix("$")
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1bfa11d67af6..ba1beab28d9a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -585,17 +585,20 @@ case class StringToMap(text: Expression, pairDelim:
Expression, keyValueDelim: E
private lazy val mapBuilder = new ArrayBasedMapBuilder(first.dataType,
first.dataType)
+ private final lazy val collationId: Int =
text.dataType.asInstanceOf[StringType].collationId
+
override def nullSafeEval(
inputString: Any,
stringDelimiter: Any,
keyValueDelimiter: Any): Any = {
- val keyValues =
-
inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String],
-1)
+ val keyValues =
CollationAwareUTF8String.splitSQL(inputString.asInstanceOf[UTF8String],
+ stringDelimiter.asInstanceOf[UTF8String], -1, collationId)
val keyValueDelimiterUTF8String =
keyValueDelimiter.asInstanceOf[UTF8String]
var i = 0
while (i < keyValues.length) {
- val keyValueArray = keyValues(i).split(keyValueDelimiterUTF8String, 2)
+ val keyValueArray = CollationAwareUTF8String.splitSQL(
+ keyValues(i), keyValueDelimiterUTF8String, 2, collationId)
val key = keyValueArray(0)
val value = if (keyValueArray.length < 2) null else keyValueArray(1)
mapBuilder.put(key, value)
@@ -610,9 +613,9 @@ case class StringToMap(text: Expression, pairDelim:
Expression, keyValueDelim: E
nullSafeCodeGen(ctx, ev, (text, pd, kvd) =>
s"""
- |UTF8String[] $keyValues = $text.split($pd, -1);
+ |UTF8String[] $keyValues = CollationAwareUTF8String.splitSQL($text,
$pd, -1, $collationId);
|for(UTF8String kvEntry: $keyValues) {
- | UTF8String[] kv = kvEntry.split($kvd, 2);
+ | UTF8String[] kv = CollationAwareUTF8String.splitSQL(kvEntry, $kvd,
2, $collationId);
| $builderTerm.put(kv[0], kv.length == 2 ? kv[1] : null);
|}
|${ev.value} = $builderTerm.build();
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 0fefcb79e2e3..7d0f6c401c0d 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.immutable.Seq
import org.apache.spark.{SparkConf, SparkException,
SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.test.SharedSparkSession
@@ -34,8 +34,9 @@ import org.apache.spark.util.collection.OpenHashMap
// scalastyle:off nonascii
class CollationSQLExpressionsSuite
- extends QueryTest
- with SharedSparkSession {
+ extends QueryTest
+ with SharedSparkSession
+ with ExpressionEvalHelper {
private val testSuppCollations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE",
"UNICODE_CI")
@@ -964,25 +965,36 @@ class CollationSQLExpressionsSuite
})
}
- test("Support StringToMap expression with collation") {
- // Supported collations
- case class StringToMapTestCase[R](t: String, p: String, k: String, c:
String, result: R)
+ test("Support `StringToMap` expression with collation") {
+ case class StringToMapTestCase[R](
+ text: String,
+ pairDelim: String,
+ keyValueDelim: String,
+ collation: String,
+ result: R)
val testCases = Seq(
StringToMapTestCase("a:1,b:2,c:3", ",", ":", "UTF8_BINARY",
Map("a" -> "1", "b" -> "2", "c" -> "3")),
- StringToMapTestCase("A-1;B-2;C-3", ";", "-", "UTF8_LCASE",
+ StringToMapTestCase("A-1xB-2xC-3", "X", "-", "UTF8_LCASE",
Map("A" -> "1", "B" -> "2", "C" -> "3")),
- StringToMapTestCase("1:a,2:b,3:c", ",", ":", "UNICODE",
+ StringToMapTestCase("1:ax2:bx3:c", "x", ":", "UNICODE",
Map("1" -> "a", "2" -> "b", "3" -> "c")),
- StringToMapTestCase("1/A!2/B!3/C", "!", "/", "UNICODE_CI",
+ StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI",
Map("1" -> "A", "2" -> "B", "3" -> "C"))
)
testCases.foreach(t => {
- val query = s"SELECT str_to_map(collate('${t.t}', '${t.c}'), '${t.p}',
'${t.k}');"
- // Result & data type
- checkAnswer(sql(query), Row(t.result))
- val dataType = MapType(StringType(t.c), StringType(t.c), true)
- assert(sql(query).schema.fields.head.dataType.sameType(dataType))
+ // Unit test.
+ val text = Literal.create(t.text, StringType(t.collation))
+ val pairDelim = Literal.create(t.pairDelim, StringType(t.collation))
+ val keyValueDelim = Literal.create(t.keyValueDelim,
StringType(t.collation))
+ checkEvaluation(StringToMap(text, pairDelim, keyValueDelim), t.result)
+ // E2E SQL test.
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) {
+ val query = s"SELECT str_to_map('${t.text}', '${t.pairDelim}',
'${t.keyValueDelim}')"
+ checkAnswer(sql(query), Row(t.result))
+ val dataType = MapType(StringType(t.collation),
StringType(t.collation), true)
+ assert(sql(query).schema.fields.head.dataType.sameType(dataType))
+ }
})
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]