This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new b06f0f18257 [SPARK-43802][SQL][3.4] Fix codegen for unhex and unbase64
with failOnError=true
b06f0f18257 is described below
commit b06f0f18257a11cf9d66d32b59ecd3b49657dbe9
Author: Adam Binford <[email protected]>
AuthorDate: Fri May 26 19:30:14 2023 -0700
[SPARK-43802][SQL][3.4] Fix codegen for unhex and unbase64 with
failOnError=true
### What changes were proposed in this pull request?
This is a backport of https://github.com/apache/spark/pull/41317.
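Fixes a codegen error in the unhex and unbase64 expressions when failOnError
is enabled, introduced in https://github.com/apache/spark/pull/37483. The
generated code interpolated the format value as a bare identifier (e.g. BASE64)
instead of a string expression. It now passes UTF8String.fromString("...")
directly, and unhex now reports its format as HEX rather than BASE64.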
### Why are the changes needed?
Codegen fails and Spark falls back to interpreted evaluation:
```
Caused by: org.codehaus.commons.compiler.CompileException: File
'generated.java', Line 47, Column 1: failed to compile:
org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 47,
Column 1: Unknown variable or type "BASE64"
```
in the code block:
```
/* 107 */ if
(!org.apache.spark.sql.catalyst.expressions.UnBase64.isValidBase64(project_value_1))
{
/* 108 */ throw
QueryExecutionErrors.invalidInputInConversionError(
/* 109 */ ((org.apache.spark.sql.types.BinaryType$)
references[1] /* to */),
/* 110 */ project_value_1,
/* 111 */ BASE64,
/* 112 */ "try_to_binary");
/* 113 */ }
```
### Does this PR introduce _any_ user-facing change?
Bug fix.
### How was this patch tested?
Extended the existing tests to evaluate expressions with failOnError
enabled, exercising that codegen path.
Closes #41334 from Kimahriman/to-binary-codegen-backport.
Authored-by: Adam Binford <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../sql/catalyst/expressions/mathExpressions.scala | 3 +-
.../catalyst/expressions/stringExpressions.scala | 3 +-
.../expressions/MathExpressionsSuite.scala | 3 ++
.../expressions/StringExpressionsSuite.scala | 4 +-
.../sql/errors/QueryExecutionErrorsSuite.scala | 46 ++++++++++++++++------
5 files changed, 43 insertions(+), 16 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index dcc821a24ea..add59a38b72 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1172,14 +1172,13 @@ case class Unhex(child: Expression, failOnError:
Boolean = false)
nullSafeCodeGen(ctx, ev, c => {
val hex = Hex.getClass.getName.stripSuffix("$")
val maybeFailOnErrorCode = if (failOnError) {
- val format = UTF8String.fromString("BASE64");
val binaryType = ctx.addReferenceObj("to", BinaryType,
BinaryType.getClass.getName)
s"""
|if (${ev.value} == null) {
| throw QueryExecutionErrors.invalidInputInConversionError(
| $binaryType,
| $c,
- | $format,
+ | UTF8String.fromString("HEX"),
| "try_to_binary");
|}
|""".stripMargin
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index c1ca86b356e..1e58384c81d 100755
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -2356,14 +2356,13 @@ case class UnBase64(child: Expression, failOnError:
Boolean = false)
nullSafeCodeGen(ctx, ev, child => {
val maybeValidateInputCode = if (failOnError) {
val unbase64 = UnBase64.getClass.getName.stripSuffix("$")
- val format = UTF8String.fromString("BASE64");
val binaryType = ctx.addReferenceObj("to", BinaryType,
BinaryType.getClass.getName)
s"""
|if (!$unbase64.isValidBase64($child)) {
| throw QueryExecutionErrors.invalidInputInConversionError(
| $binaryType,
| $child,
- | $format,
+ | UTF8String.fromString("BASE64"),
| "try_to_binary");
|}
""".stripMargin
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 437f7ddee01..823a6d2ce86 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -615,6 +615,9 @@ class MathExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(Unhex(Literal("GG")), null)
checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35))
checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69))
+
+ // failOnError
+ checkEvaluation(Unhex(Literal("12345"), true), Array[Byte](1, 35, 69))
// scalastyle:off
// Turn off scala style for non-ascii chars
checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")),
"三重的".getBytes(StandardCharsets.UTF_8))
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 017f4483e88..399aedd7b71 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -468,7 +468,9 @@ class StringExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==",
create_row("abdef"))
checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef"))
checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null,
create_row("abdef"))
- checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA=="))
+
+ // failOnError
+ checkEvaluation(Base64(UnBase64(a, true)), "AQIDBA==",
create_row("AQIDBA=="))
checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes))
checkEvaluation(Base64(b), "", create_row(Array.empty[Byte]))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 865d6735cf6..27dbe45952e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -28,6 +28,7 @@ import org.mockito.Mockito.{mock, spy, when}
import org.apache.spark._
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest,
Row, SaveMode}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.util.BadRecordException
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry,
JDBCOptions}
import
org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
@@ -52,17 +53,40 @@ class QueryExecutionErrorsSuite
import testImplicits._
- test("CONVERSION_INVALID_INPUT: to_binary conversion function") {
- checkError(
- exception = intercept[SparkIllegalArgumentException] {
- sql("select to_binary('???', 'base64')").collect()
- },
- errorClass = "CONVERSION_INVALID_INPUT",
- parameters = Map(
- "str" -> "'???'",
- "fmt" -> "'BASE64'",
- "targetType" -> "\"BINARY\"",
- "suggestion" -> "`try_to_binary`"))
+ test("CONVERSION_INVALID_INPUT: to_binary conversion function base64") {
+ for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) {
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) {
+ val exception = intercept[SparkException] {
+ Seq(("???")).toDF("a").selectExpr("to_binary(a, 'base64')").collect()
+ }.getCause.asInstanceOf[SparkIllegalArgumentException]
+ checkError(
+ exception,
+ errorClass = "CONVERSION_INVALID_INPUT",
+ parameters = Map(
+ "str" -> "'???'",
+ "fmt" -> "'BASE64'",
+ "targetType" -> "\"BINARY\"",
+ "suggestion" -> "`try_to_binary`"))
+ }
+ }
+ }
+
+ test("CONVERSION_INVALID_INPUT: to_binary conversion function hex") {
+ for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) {
+ withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) {
+ val exception = intercept[SparkException] {
+ Seq(("???")).toDF("a").selectExpr("to_binary(a, 'hex')").collect()
+ }.getCause.asInstanceOf[SparkIllegalArgumentException]
+ checkError(
+ exception,
+ errorClass = "CONVERSION_INVALID_INPUT",
+ parameters = Map(
+ "str" -> "'???'",
+ "fmt" -> "'HEX'",
+ "targetType" -> "\"BINARY\"",
+ "suggestion" -> "`try_to_binary`"))
+ }
+ }
}
private def getAesInputs(): (DataFrame, DataFrame) = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]