This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 484e7ac9acef [SPARK-48472][SQL] Enable reflect expressions with 
collated strings
484e7ac9acef is described below

commit 484e7ac9acefc46ba6f1bc3019251441d9bb1507
Author: Mihailo Aleksic <[email protected]>
AuthorDate: Wed Jun 19 16:23:47 2024 +0800

    [SPARK-48472][SQL] Enable reflect expressions with collated strings
    
    ### What changes were proposed in this pull request?
    
    Changes made in this pull request enable collation of strings in "reflect" 
expressions.
    
    ### Why are the changes needed?
    
    These changes are a bug fix which enables users to use the feature mentioned above.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Using a unit test which can be found in 
sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46991 from mihailoale-db/ReflectionCollationFix.
    
    Authored-by: Mihailo Aleksic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/CallMethodViaReflection.scala      | 23 ++++++----
 .../spark/sql/CollationSQLExpressionsSuite.scala   | 51 ++++++++++++++++++++++
 2 files changed, 66 insertions(+), 8 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
index c42b54222f17..13ea8c77c41b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
@@ -26,6 +26,8 @@ import 
org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, 
TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
@@ -77,12 +79,12 @@ case class CallMethodViaReflection(
       )
     } else {
       val unexpectedParameter = children.zipWithIndex.collectFirst {
-        case (e, 0) if !(e.dataType == StringType && e.foldable) =>
+        case (e, 0) if !(e.dataType.isInstanceOf[StringType] && e.foldable) =>
           DataTypeMismatch(
             errorSubClass = "NON_FOLDABLE_INPUT",
             messageParameters = Map(
               "inputName" -> toSQLId("class"),
-              "inputType" -> toSQLType(StringType),
+              "inputType" -> toSQLType(StringTypeAnyCollation),
               "inputExpr" -> toSQLExpr(children.head)
             )
           )
@@ -90,12 +92,12 @@ case class CallMethodViaReflection(
           DataTypeMismatch(
             errorSubClass = "UNEXPECTED_NULL",
             messageParameters = Map("exprName" -> toSQLId("class")))
-        case (e, 1) if !(e.dataType == StringType && e.foldable) =>
+        case (e, 1) if !(e.dataType.isInstanceOf[StringType] && e.foldable) =>
           DataTypeMismatch(
             errorSubClass = "NON_FOLDABLE_INPUT",
             messageParameters = Map(
               "inputName" -> toSQLId("method"),
-              "inputType" -> toSQLType(StringType),
+              "inputType" -> toSQLType(StringTypeAnyCollation),
               "inputExpr" -> toSQLExpr(children(1))
             )
           )
@@ -103,14 +105,16 @@ case class CallMethodViaReflection(
           DataTypeMismatch(
             errorSubClass = "UNEXPECTED_NULL",
             messageParameters = Map("exprName" -> toSQLId("method")))
-        case (e, idx) if idx > 1 && 
!CallMethodViaReflection.typeMapping.contains(e.dataType) =>
+        case (e, idx) if idx > 1 &&
+          (!CallMethodViaReflection.typeMapping.contains(e.dataType)
+            && !e.dataType.isInstanceOf[StringType]) =>
           DataTypeMismatch(
             errorSubClass = "UNEXPECTED_INPUT_TYPE",
             messageParameters = Map(
               "paramIndex" -> ordinalNumber(idx),
               "requiredType" -> toSQLType(
                 TypeCollection(BooleanType, ByteType, ShortType,
-                  IntegerType, LongType, FloatType, DoubleType, StringType)),
+                  IntegerType, LongType, FloatType, DoubleType, 
StringTypeAnyCollation)),
               "inputSql" -> toSQLExpr(e),
               "inputType" -> toSQLType(e.dataType))
           )
@@ -134,7 +138,7 @@ case class CallMethodViaReflection(
   }
 
   override def nullable: Boolean = true
-  override val dataType: DataType = StringType
+  override val dataType: DataType = SQLConf.get.defaultStringType
   override protected def initializeInternal(partitionIndex: Int): Unit = {}
 
   override protected def evalInternal(input: InternalRow): Any = {
@@ -230,7 +234,10 @@ object CallMethodViaReflection {
         // Argument type must match. That is, either the method's argument 
type matches one of the
         // acceptable types defined in typeMapping, or it is a super type of 
the acceptable types.
         candidateTypes.zip(argTypes).forall { case (candidateType, argType) =>
-          typeMapping(argType).exists(candidateType.isAssignableFrom)
+          if (!argType.isInstanceOf[StringType]) {
+            typeMapping(argType).exists(candidateType.isAssignableFrom)
+          }
+          else candidateType.isAssignableFrom(classOf[String])
         }
       }
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 0c54ccb7cfb1..0a7b513457a5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -23,6 +23,7 @@ import java.text.SimpleDateFormat
 import scala.collection.immutable.Seq
 
 import org.apache.spark.{SparkConf, SparkException, 
SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
@@ -2020,6 +2021,56 @@ class CollationSQLExpressionsSuite
     })
   }
 
+  test("Reflect expressions with collated strings") {
+    // be aware that output of java.util.UUID.fromString is always lowercase
+
+    case class ReflectExpressions(
+      left: String,
+      leftCollation: String,
+      right: String,
+      rightCollation: String,
+      result: Boolean
+    )
+
+    val testCases = Seq(
+      ReflectExpressions("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary",
+        "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary", true),
+      ReflectExpressions("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary",
+        "A5Cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary", false),
+
+      ReflectExpressions("A5cf6C42-0C85-418f-af6c-3E4E5b1328f2", "utf8_binary",
+        "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_lcase", true),
+      ReflectExpressions("A5cf6C42-0C85-418f-af6c-3E4E5b1328f2", "utf8_binary",
+        "A5Cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_lcase", true)
+    )
+    testCases.foreach(testCase => {
+      val query =
+        s"""
+           |SELECT REFLECT('java.util.UUID', 'fromString',
+           |collate('${testCase.left}', '${testCase.leftCollation}'))=
+           |collate('${testCase.right}', '${testCase.rightCollation}');
+           |""".stripMargin
+      val testQuery = sql(query)
+      checkAnswer(testQuery, Row(testCase.result))
+    })
+
+    val queryPass =
+      s"""
+         |SELECT REFLECT('java.lang.Integer', 'toHexString',2);
+         |""".stripMargin
+    val testQueryPass = sql(queryPass)
+    checkAnswer(testQueryPass, Row("2"))
+
+    val queryFail =
+      s"""
+         |SELECT REFLECT('java.lang.Integer', 'toHexString',"2");
+         |""".stripMargin
+    val typeException = intercept[ExtendedAnalysisException] {
+      sql(queryFail).collect()
+    }
+    assert(typeException.getErrorClass === 
"DATATYPE_MISMATCH.UNEXPECTED_STATIC_METHOD")
+  }
+
   // TODO: Add more tests for other SQL expressions
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to