This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d8aff6e04941 [SPARK-48937][SQL] Add collation support for StringToMap 
string expressions
d8aff6e04941 is described below

commit d8aff6e0494198c48962fbe6e6b93edff2855067
Author: Uros Bojanic <[email protected]>
AuthorDate: Fri Aug 9 22:00:10 2024 +0800

    [SPARK-48937][SQL] Add collation support for StringToMap string expressions
    
    ### What changes were proposed in this pull request?
    Add collation awareness for `StringToMap` string expression.
    
    ### Why are the changes needed?
    `StringToMap` should be collation aware when splitting strings on specified 
delimiters.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `StringToMap` is now collation aware.
    
    ### How was this patch tested?
    New unit tests and end-to-end SQL tests for `str_to_map`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47621 from uros-db/fix-str-to-map.
    
    Authored-by: Uros Bojanic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/util/CollationAwareUTF8String.java    | 57 ++++++++++++++++++++++
 .../spark/sql/catalyst/util/CollationSupport.java  | 30 ++----------
 .../expressions/codegen/CodeGenerator.scala        |  3 +-
 .../catalyst/expressions/complexTypeCreator.scala  | 13 +++--
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 40 +++++++++------
 5 files changed, 96 insertions(+), 47 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
index 501d173fc485..b57f172428ac 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -32,10 +32,13 @@ import static 
org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType;
 
 import java.text.CharacterIterator;
 import java.text.StringCharacterIterator;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
+import java.util.regex.Pattern;
 
 /**
  * Utility class for collation-aware UTF8String operations.
@@ -1226,6 +1229,60 @@ public class CollationAwareUTF8String {
     return UTF8String.fromString(src.substring(0, charIndex));
   }
 
+  public static UTF8String[] splitSQL(final UTF8String input, final UTF8String 
delim,
+      final int limit, final int collationId) {
+    if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
+      return input.split(delim, limit);
+    } else if 
(CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) {
+      return lowercaseSplitSQL(input, delim, limit);
+    } else {
+      return icuSplitSQL(input, delim, limit, collationId);
+    }
+  }
+
+  public static UTF8String[] lowercaseSplitSQL(final UTF8String string, final 
UTF8String delimiter,
+      final int limit) {
+      if (delimiter.numBytes() == 0) return new UTF8String[] { string };
+      if (string.numBytes() == 0) return new UTF8String[] { 
UTF8String.EMPTY_UTF8 };
+      Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()),
+        CollationSupport.lowercaseRegexFlags);
+      String[] splits = pattern.split(string.toString(), limit);
+      UTF8String[] res = new UTF8String[splits.length];
+      for (int i = 0; i < res.length; i++) {
+        res[i] = UTF8String.fromString(splits[i]);
+      }
+      return res;
+  }
+
+  public static UTF8String[] icuSplitSQL(final UTF8String string, final 
UTF8String delimiter,
+      final int limit, final int collationId) {
+    if (delimiter.numBytes() == 0) return new UTF8String[] { string };
+    if (string.numBytes() == 0) return new UTF8String[] { 
UTF8String.EMPTY_UTF8 };
+    List<UTF8String> strings = new ArrayList<>();
+    String target = string.toString(), pattern = delimiter.toString();
+    StringSearch stringSearch = CollationFactory.getStringSearch(target, 
pattern, collationId);
+    int start = 0, end;
+    while ((end = stringSearch.next()) != StringSearch.DONE) {
+      if (limit > 0 && strings.size() == limit - 1) {
+        break;
+      }
+      strings.add(UTF8String.fromString(target.substring(start, end)));
+      start = end + stringSearch.getMatchLength();
+    }
+    if (start <= target.length()) {
+      strings.add(UTF8String.fromString(target.substring(start)));
+    }
+    if (limit == 0) {
+      // Remove trailing empty strings
+      int i = strings.size() - 1;
+      while (i >= 0 && strings.get(i).numBytes() == 0) {
+        strings.remove(i);
+        i--;
+      }
+    }
+    return strings.toArray(new UTF8String[0]);
+  }
+
   // TODO: Add more collation-aware UTF8String operations here.
 
 }
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index f160661af389..651683796877 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -20,8 +20,6 @@ import com.ibm.icu.text.StringSearch;
 
 import org.apache.spark.unsafe.types.UTF8String;
 
-import java.util.ArrayList;
-import java.util.List;
 import java.util.Map;
 import java.util.regex.Pattern;
 
@@ -62,33 +60,11 @@ public final class CollationSupport {
       return string.splitSQL(delimiter, -1);
     }
     public static UTF8String[] execLowercase(final UTF8String string, final 
UTF8String delimiter) {
-      if (delimiter.numBytes() == 0) return new UTF8String[] { string };
-      if (string.numBytes() == 0) return new UTF8String[] { 
UTF8String.EMPTY_UTF8 };
-      Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()),
-        CollationSupport.lowercaseRegexFlags);
-      String[] splits = pattern.split(string.toString(), -1);
-      UTF8String[] res = new UTF8String[splits.length];
-      for (int i = 0; i < res.length; i++) {
-        res[i] = UTF8String.fromString(splits[i]);
-      }
-      return res;
+      return CollationAwareUTF8String.lowercaseSplitSQL(string, delimiter, -1);
     }
     public static UTF8String[] execICU(final UTF8String string, final 
UTF8String delimiter,
         final int collationId) {
-      if (delimiter.numBytes() == 0) return new UTF8String[] { string };
-      if (string.numBytes() == 0) return new UTF8String[] { 
UTF8String.EMPTY_UTF8 };
-      List<UTF8String> strings = new ArrayList<>();
-      String target = string.toString(), pattern = delimiter.toString();
-      StringSearch stringSearch = CollationFactory.getStringSearch(target, 
pattern, collationId);
-      int start = 0, end;
-      while ((end = stringSearch.next()) != StringSearch.DONE) {
-        strings.add(UTF8String.fromString(target.substring(start, end)));
-        start = end + stringSearch.getMatchLength();
-      }
-      if (start <= target.length()) {
-        strings.add(UTF8String.fromString(target.substring(start)));
-      }
-      return strings.toArray(new UTF8String[0]);
+      return CollationAwareUTF8String.icuSplitSQL(string, delimiter, -1, 
collationId);
     }
   }
 
@@ -696,7 +672,7 @@ public final class CollationSupport {
     return 
CollationFactory.fetchCollation(collationId).supportsLowercaseEquality;
   }
 
-  private static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | 
Pattern.CASE_INSENSITIVE;
+  static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | 
Pattern.CASE_INSENSITIVE;
   public static int collationAwareRegexFlags(final int collationId) {
     return supportsLowercaseRegex(collationId) ? lowercaseRegexFlags : 0;
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a39c10866984..30c00f5bf96b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -40,7 +40,7 @@ import 
org.apache.spark.sql.catalyst.encoders.HashableWeakReference
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, 
CollationSupport, MapData, SQLOrderingUtil, UnsafeRowUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, 
CollationAwareUTF8String, CollationFactory, CollationSupport, MapData, 
SQLOrderingUtil, UnsafeRowUtils}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -1529,6 +1529,7 @@ object CodeGenerator extends Logging {
       classOf[TaskContext].getName,
       classOf[TaskKilledException].getName,
       classOf[InputMetrics].getName,
+      classOf[CollationAwareUTF8String].getName,
       classOf[CollationFactory].getName,
       classOf[CollationSupport].getName,
       QueryExecutionErrors.getClass.getName.stripSuffix("$")
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1bfa11d67af6..ba1beab28d9a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -585,17 +585,20 @@ case class StringToMap(text: Expression, pairDelim: 
Expression, keyValueDelim: E
 
   private lazy val mapBuilder = new ArrayBasedMapBuilder(first.dataType, 
first.dataType)
 
+  private final lazy val collationId: Int = 
text.dataType.asInstanceOf[StringType].collationId
+
   override def nullSafeEval(
       inputString: Any,
       stringDelimiter: Any,
       keyValueDelimiter: Any): Any = {
-    val keyValues =
-      
inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String],
 -1)
+    val keyValues = 
CollationAwareUTF8String.splitSQL(inputString.asInstanceOf[UTF8String],
+      stringDelimiter.asInstanceOf[UTF8String], -1, collationId)
     val keyValueDelimiterUTF8String = 
keyValueDelimiter.asInstanceOf[UTF8String]
 
     var i = 0
     while (i < keyValues.length) {
-      val keyValueArray = keyValues(i).split(keyValueDelimiterUTF8String, 2)
+      val keyValueArray = CollationAwareUTF8String.splitSQL(
+        keyValues(i), keyValueDelimiterUTF8String, 2, collationId)
       val key = keyValueArray(0)
       val value = if (keyValueArray.length < 2) null else keyValueArray(1)
       mapBuilder.put(key, value)
@@ -610,9 +613,9 @@ case class StringToMap(text: Expression, pairDelim: 
Expression, keyValueDelim: E
 
     nullSafeCodeGen(ctx, ev, (text, pd, kvd) =>
       s"""
-         |UTF8String[] $keyValues = $text.split($pd, -1);
+         |UTF8String[] $keyValues = CollationAwareUTF8String.splitSQL($text, 
$pd, -1, $collationId);
          |for(UTF8String kvEntry: $keyValues) {
-         |  UTF8String[] kv = kvEntry.split($kvd, 2);
+         |  UTF8String[] kv = CollationAwareUTF8String.splitSQL(kvEntry, $kvd, 
2, $collationId);
          |  $builderTerm.put(kv[0], kv.length == 2 ? kv[1] : null);
          |}
          |${ev.value} = $builderTerm.build();
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 0fefcb79e2e3..7d0f6c401c0d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.immutable.Seq
 
 import org.apache.spark.{SparkConf, SparkException, 
SparkIllegalArgumentException, SparkRuntimeException}
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.test.SharedSparkSession
@@ -34,8 +34,9 @@ import org.apache.spark.util.collection.OpenHashMap
 
 // scalastyle:off nonascii
 class CollationSQLExpressionsSuite
-  extends QueryTest
-  with SharedSparkSession {
+    extends QueryTest
+    with SharedSparkSession
+    with ExpressionEvalHelper {
 
   private val testSuppCollations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", 
"UNICODE_CI")
 
@@ -964,25 +965,36 @@ class CollationSQLExpressionsSuite
     })
   }
 
-  test("Support StringToMap expression with collation") {
-    // Supported collations
-    case class StringToMapTestCase[R](t: String, p: String, k: String, c: 
String, result: R)
+  test("Support `StringToMap` expression with collation") {
+    case class StringToMapTestCase[R](
+        text: String,
+        pairDelim: String,
+        keyValueDelim: String,
+        collation: String,
+        result: R)
     val testCases = Seq(
       StringToMapTestCase("a:1,b:2,c:3", ",", ":", "UTF8_BINARY",
         Map("a" -> "1", "b" -> "2", "c" -> "3")),
-      StringToMapTestCase("A-1;B-2;C-3", ";", "-", "UTF8_LCASE",
+      StringToMapTestCase("A-1xB-2xC-3", "X", "-", "UTF8_LCASE",
         Map("A" -> "1", "B" -> "2", "C" -> "3")),
-      StringToMapTestCase("1:a,2:b,3:c", ",", ":", "UNICODE",
+      StringToMapTestCase("1:ax2:bx3:c", "x", ":", "UNICODE",
         Map("1" -> "a", "2" -> "b", "3" -> "c")),
-      StringToMapTestCase("1/A!2/B!3/C", "!", "/", "UNICODE_CI",
+      StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI",
         Map("1" -> "A", "2" -> "B", "3" -> "C"))
     )
     testCases.foreach(t => {
-      val query = s"SELECT str_to_map(collate('${t.t}', '${t.c}'), '${t.p}', 
'${t.k}');"
-      // Result & data type
-      checkAnswer(sql(query), Row(t.result))
-      val dataType = MapType(StringType(t.c), StringType(t.c), true)
-      assert(sql(query).schema.fields.head.dataType.sameType(dataType))
+      // Unit test.
+      val text = Literal.create(t.text, StringType(t.collation))
+      val pairDelim = Literal.create(t.pairDelim, StringType(t.collation))
+      val keyValueDelim = Literal.create(t.keyValueDelim, 
StringType(t.collation))
+      checkEvaluation(StringToMap(text, pairDelim, keyValueDelim), t.result)
+      // E2E SQL test.
+      withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) {
+        val query = s"SELECT str_to_map('${t.text}', '${t.pairDelim}', 
'${t.keyValueDelim}')"
+        checkAnswer(sql(query), Row(t.result))
+        val dataType = MapType(StringType(t.collation), 
StringType(t.collation), true)
+        assert(sql(query).schema.fields.head.dataType.sameType(dataType))
+      }
     })
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to