This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new b06f0f18257 [SPARK-43802][SQL][3.4] Fix codegen for unhex and unbase64 
with failOnError=true
b06f0f18257 is described below

commit b06f0f18257a11cf9d66d32b59ecd3b49657dbe9
Author: Adam Binford <[email protected]>
AuthorDate: Fri May 26 19:30:14 2023 -0700

    [SPARK-43802][SQL][3.4] Fix codegen for unhex and unbase64 with 
failOnError=true
    
    ### What changes were proposed in this pull request?
    
    This is a backport of https://github.com/apache/spark/pull/41317.
    
    Fixes a codegen error for the unhex and unbase64 expressions when
    failOnError is enabled. The error was introduced in
    https://github.com/apache/spark/pull/37483.
    
    ### Why are the changes needed?
    
    Codegen fails and Spark falls back to interpreted evaluation:
    ```
    Caused by: org.codehaus.commons.compiler.CompileException: File 
'generated.java', Line 47, Column 1: failed to compile: 
org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 47, 
Column 1: Unknown variable or type "BASE64"
    ```
    in the code block:
    ```
    /* 107 */         if 
(!org.apache.spark.sql.catalyst.expressions.UnBase64.isValidBase64(project_value_1))
 {
    /* 108 */           throw 
QueryExecutionErrors.invalidInputInConversionError(
    /* 109 */             ((org.apache.spark.sql.types.BinaryType$) 
references[1] /* to */),
    /* 110 */             project_value_1,
    /* 111 */             BASE64,
    /* 112 */             "try_to_binary");
    /* 113 */         }
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Bug fix.
    
    ### How was this patch tested?
    
    Extended the existing tests to evaluate an expression with failOnError
    enabled, so that this path of the codegen is exercised.
    
    Closes #41334 from Kimahriman/to-binary-codegen-backport.
    
    Authored-by: Adam Binford <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/catalyst/expressions/mathExpressions.scala |  3 +-
 .../catalyst/expressions/stringExpressions.scala   |  3 +-
 .../expressions/MathExpressionsSuite.scala         |  3 ++
 .../expressions/StringExpressionsSuite.scala       |  4 +-
 .../sql/errors/QueryExecutionErrorsSuite.scala     | 46 ++++++++++++++++------
 5 files changed, 43 insertions(+), 16 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index dcc821a24ea..add59a38b72 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1172,14 +1172,13 @@ case class Unhex(child: Expression, failOnError: 
Boolean = false)
     nullSafeCodeGen(ctx, ev, c => {
       val hex = Hex.getClass.getName.stripSuffix("$")
       val maybeFailOnErrorCode = if (failOnError) {
-        val format = UTF8String.fromString("BASE64");
         val binaryType = ctx.addReferenceObj("to", BinaryType, 
BinaryType.getClass.getName)
         s"""
            |if (${ev.value} == null) {
            |  throw QueryExecutionErrors.invalidInputInConversionError(
            |    $binaryType,
            |    $c,
-           |    $format,
+           |    UTF8String.fromString("HEX"),
            |    "try_to_binary");
            |}
            |""".stripMargin
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index c1ca86b356e..1e58384c81d 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -2356,14 +2356,13 @@ case class UnBase64(child: Expression, failOnError: 
Boolean = false)
     nullSafeCodeGen(ctx, ev, child => {
       val maybeValidateInputCode = if (failOnError) {
         val unbase64 = UnBase64.getClass.getName.stripSuffix("$")
-        val format = UTF8String.fromString("BASE64");
         val binaryType = ctx.addReferenceObj("to", BinaryType, 
BinaryType.getClass.getName)
         s"""
            |if (!$unbase64.isValidBase64($child)) {
            |  throw QueryExecutionErrors.invalidInputInConversionError(
            |    $binaryType,
            |    $child,
-           |    $format,
+           |    UTF8String.fromString("BASE64"),
            |    "try_to_binary");
            |}
        """.stripMargin
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 437f7ddee01..823a6d2ce86 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -615,6 +615,9 @@ class MathExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(Unhex(Literal("GG")), null)
     checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35))
     checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69))
+
+    // failOnError
+    checkEvaluation(Unhex(Literal("12345"), true), Array[Byte](1, 35, 69))
     // scalastyle:off
     // Turn off scala style for non-ascii chars
     checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), 
"三重的".getBytes(StandardCharsets.UTF_8))
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 017f4483e88..399aedd7b71 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -468,7 +468,9 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", 
create_row("abdef"))
     checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef"))
     checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, 
create_row("abdef"))
-    checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA=="))
+
+    // failOnError
+    checkEvaluation(Base64(UnBase64(a, true)), "AQIDBA==", 
create_row("AQIDBA=="))
 
     checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes))
     checkEvaluation(Base64(b), "", create_row(Array.empty[Byte]))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 865d6735cf6..27dbe45952e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -28,6 +28,7 @@ import org.mockito.Mockito.{mock, spy, when}
 
 import org.apache.spark._
 import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, 
Row, SaveMode}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.util.BadRecordException
 import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, 
JDBCOptions}
 import 
org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
@@ -52,17 +53,40 @@ class QueryExecutionErrorsSuite
 
   import testImplicits._
 
-  test("CONVERSION_INVALID_INPUT: to_binary conversion function") {
-    checkError(
-      exception = intercept[SparkIllegalArgumentException] {
-        sql("select to_binary('???', 'base64')").collect()
-      },
-      errorClass = "CONVERSION_INVALID_INPUT",
-      parameters = Map(
-        "str" -> "'???'",
-        "fmt" -> "'BASE64'",
-        "targetType" -> "\"BINARY\"",
-        "suggestion" -> "`try_to_binary`"))
+  test("CONVERSION_INVALID_INPUT: to_binary conversion function base64") {
+    for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) {
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) {
+        val exception = intercept[SparkException] {
+          Seq(("???")).toDF("a").selectExpr("to_binary(a, 'base64')").collect()
+        }.getCause.asInstanceOf[SparkIllegalArgumentException]
+        checkError(
+          exception,
+          errorClass = "CONVERSION_INVALID_INPUT",
+          parameters = Map(
+            "str" -> "'???'",
+            "fmt" -> "'BASE64'",
+            "targetType" -> "\"BINARY\"",
+            "suggestion" -> "`try_to_binary`"))
+      }
+    }
+  }
+
+  test("CONVERSION_INVALID_INPUT: to_binary conversion function hex") {
+    for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) {
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) {
+        val exception = intercept[SparkException] {
+          Seq(("???")).toDF("a").selectExpr("to_binary(a, 'hex')").collect()
+        }.getCause.asInstanceOf[SparkIllegalArgumentException]
+        checkError(
+          exception,
+          errorClass = "CONVERSION_INVALID_INPUT",
+          parameters = Map(
+            "str" -> "'???'",
+            "fmt" -> "'HEX'",
+            "targetType" -> "\"BINARY\"",
+            "suggestion" -> "`try_to_binary`"))
+      }
+    }
   }
 
   private def getAesInputs(): (DataFrame, DataFrame) = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to