This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2e1c3dc8004b [SPARK-50087] Robust handling of boolean expressions in CASE WHEN for MsSqlServer and future connectors
2e1c3dc8004b is described below
commit 2e1c3dc8004b4f003cde8dfae6857f5bef4bb170
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Nov 21 20:56:59 2024 +0800
[SPARK-50087] Robust handling of boolean expressions in CASE WHEN for MsSqlServer and future connectors
### What changes were proposed in this pull request?
This PR proposes to propagate the `isPredicate` info in `V2ExpressionBuilder` and wrap the children of a CASE WHEN expression (only those that are `Predicate`s) with `IIF(<>, 1, 0)` for MsSqlServer. This forces the branches to return an int instead of a boolean, as SqlServer cannot handle boolean expressions as a return type in CASE WHEN.
E.g.
```CASE WHEN ... ELSE a = b END```
Old behavior:
```CASE WHEN ... ELSE a = b END = 1```
New behavior:
Since a `= 1` is appended to the CASE WHEN for SqlServer, the THEN and ELSE blocks must return an int. Therefore the final expression becomes:
```CASE WHEN ... ELSE IIF(a = b, 1, 0) END = 1```
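For illustration, here is a minimal standalone T-SQL snippet (not part of this patch) showing the constraint: T-SQL has no boolean result type for expressions, so a bare comparison cannot appear as a THEN/ELSE result, while an `IIF`-wrapped branch yielding an int is accepted.
```
-- Rejected by SQL Server: a comparison is not a valid THEN/ELSE result.
-- SELECT CASE WHEN 1 = 1 THEN (2 = 2) ELSE (3 = 3) END;
-- Accepted: IIF turns each comparison into an int, so the whole CASE WHEN
-- yields 1 or 0 and can itself be compared with `= 1`.
SELECT CASE WHEN 1 = 1 THEN IIF(2 = 2, 1, 0) ELSE IIF(3 = 3, 1, 0) END;
```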
### Why are the changes needed?
A user cannot query MsSqlServer data with CASE WHEN or IF clauses that return a boolean value.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added tests to MsSqlServerIntegrationSuite
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48621 from andrej-db/SPARK-50087-CaseWhen.
Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: andrej-db <[email protected]>
Co-authored-by: Andrej Gobeljić <[email protected]>
Co-authored-by: andrej-gobeljic_data <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 79 ++++++++++++++++++++++
.../sql/catalyst/util/V2ExpressionBuilder.scala | 6 +-
.../org/apache/spark/sql/jdbc/JdbcDialects.scala | 13 ++++
.../apache/spark/sql/jdbc/MsSqlServerDialect.scala | 24 +++++--
4 files changed, 114 insertions(+), 8 deletions(-)
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index d884ad4c6246..fd7efb1efb76 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -20,7 +20,11 @@ package org.apache.spark.sql.jdbc.v2
import java.sql.Connection
import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker
import org.apache.spark.sql.types._
@@ -37,6 +41,17 @@ import org.apache.spark.tags.DockerTest
@DockerTest
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
+  def getExternalEngineQuery(executedPlan: SparkPlan): String = {
+    getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery
+  }
+
+  def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = {
+    val queryNode = executedPlan.collect { case r: RowDataSourceScanExec =>
+      r
+    }.head
+    queryNode.rdd
+  }
+
override def excluded: Seq[String] = Seq(
"simple scan with OFFSET",
"simple scan with LIMIT and OFFSET",
@@ -146,4 +161,68 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
|""".stripMargin)
assert(df.collect().length == 2)
}
+
+ test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") {
+ val df = sql(
+ s"""|SELECT * FROM $catalogName.employee
+ |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name =
'Wizard') END
+ |""".stripMargin
+ )
+
+ // scalastyle:off
+ assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+ """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE
WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <>
'Wizard'), 1, 0) END = 1) """
+ )
+ // scalastyle:on
+ df.collect()
+ }
+
+ test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true
test") {
+ val df = sql(
+ s"""|SELECT * FROM $catalogName.employee
+ |WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1)
END
+ |""".stripMargin
+ )
+
+ // scalastyle:off
+ assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+ """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE
WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """
+ )
+ // scalastyle:on
+ df.collect()
+ }
+
+ test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") {
+ val df = sql(
+ s"""|SELECT * FROM $catalogName.employee
+ |WHERE CASE WHEN (name = 'Legolas') THEN
+ | CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name =
'Gandalf') END
+ | ELSE (name = 'Sauron') END
+ |""".stripMargin
+ )
+
+ // scalastyle:off
+ assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+ """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE
WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name"
= 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE
IIF(("name" = 'Sauron'), 1, 0) END = 1) """
+ )
+ // scalastyle:on
+ df.collect()
+ }
+
+ test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") {
+ val df = sql(
+ s"""|SELECT * FROM $catalogName.employee
+ |WHERE CASE WHEN (name = 'Legolas') THEN
+ | CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END
+ | ELSE 'Sauron' END = name
+ |""".stripMargin
+ )
+
+ // scalastyle:off
+ assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+ """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name"
IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf'
THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """
+ )
+ // scalastyle:on
+ df.collect()
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 61a26d7a4fbd..b0ce2bb4293e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -221,8 +221,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
    case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
    case caseWhen @ CaseWhen(branches, elseValue) =>
      val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
-      val values = branches.map(_._2).flatMap(generateExpression(_))
-      val elseExprOpt = elseValue.flatMap(generateExpression(_))
+      val values = branches.map(_._2).flatMap(generateExpression(_, isPredicate))
+      val elseExprOpt = elseValue.flatMap(generateExpression(_, isPredicate))
      if (conditions.length == branches.length && values.length == branches.length &&
        elseExprOpt.size == elseValue.size) {
        val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
@@ -421,7 +421,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
      children: Seq[Expression],
      dataType: DataType,
      isPredicate: Boolean): Option[V2Expression] = {
-    val childrenExpressions = children.flatMap(generateExpression(_))
+    val childrenExpressions = children.flatMap(generateExpression(_, isPredicate))
    if (childrenExpressions.length == children.length) {
      if (isPredicate && dataType.isInstanceOf[BooleanType]) {
        Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
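For intuition, here is a minimal self-contained sketch (hypothetical simplified types, not Spark's real classes) of the propagation fixed above: the `isPredicate` flag set on a CASE WHEN now reaches its THEN/ELSE values, not only its WHEN conditions, so boolean-valued branches are built as predicates that a dialect can later rewrite.
```
// Hypothetical simplified model; names are illustrative, not Spark's real API.
sealed trait Expr
case class Cmp(l: String, r: String) extends Expr
case class CaseWhen(branches: Seq[(Expr, Expr)], elseValue: Option[Expr]) extends Expr

def generate(e: Expr, isPredicate: Boolean = false): String = e match {
  // A comparison built as a predicate is tagged so a dialect can rewrite it later.
  case Cmp(l, r) => if (isPredicate) s"PRED($l = $r)" else s"($l = $r)"
  case CaseWhen(bs, ev) =>
    val conds = bs.map { case (c, _) => generate(c, isPredicate = true) }
    val vals  = bs.map { case (_, v) => generate(v, isPredicate) } // was generate(v)
    val els   = ev.map(generate(_, isPredicate)).toSeq            // was generate(_)
    (conds.zip(vals).map { case (c, v) => s"WHEN $c THEN $v" } ++ els.map("ELSE " + _))
      .mkString("CASE ", " ", " END")
}

// generate(CaseWhen(Seq((Cmp("a", "b"), Cmp("c", "d"))), Some(Cmp("e", "f"))), isPredicate = true)
// => CASE WHEN PRED(a = b) THEN PRED(c = d) ELSE PRED(e = f) END
```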
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 3bf1390cb664..81ad1a6d38bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
+import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils}
@@ -377,6 +378,18 @@ abstract class JdbcDialect extends Serializable with Logging {
}
private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
+  // Some dialects do not support the boolean type, so this utility function is
+  // provided to generate a SQL string without boolean values.
+  protected def inputToSQLNoBool(input: Expression): String = input match {
+    case p: Predicate if p.name() == "ALWAYS_TRUE" => "1"
+    case p: Predicate if p.name() == "ALWAYS_FALSE" => "0"
+    case p: Predicate => predicateToIntSQL(inputToSQL(p))
+    case _ => inputToSQL(input)
+  }
+
+  protected def predicateToIntSQL(input: String): String =
+    "CASE WHEN " + input + " THEN 1 ELSE 0 END"
+
override def visitLiteral(literal: Literal[_]): String = {
Option(literal.value()).map(v =>
compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString)
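A hedged standalone sketch (simplified stand-in classes, not the real builders) of the design choice here: the base builder emits the portable `CASE WHEN ... THEN 1 ELSE 0 END` form, and a dialect such as MsSqlServer (next diff) only overrides `predicateToIntSQL` to swap in its shorter `IIF` spelling.
```
// Simplified stand-ins; the real classes extend V2ExpressionSQLBuilder and
// operate on connector Expression/Predicate nodes, not plain strings.
class BaseBuilderSketch {
  def predicateToIntSQL(input: String): String =
    "CASE WHEN " + input + " THEN 1 ELSE 0 END"
}

class MsSqlServerBuilderSketch extends BaseBuilderSketch {
  override def predicateToIntSQL(input: String): String =
    "IIF(" + input + ", 1, 0)"
}

// new BaseBuilderSketch().predicateToIntSQL("""("name" = 'Elf')""")
//   => CASE WHEN ("name" = 'Elf') THEN 1 ELSE 0 END
// new MsSqlServerBuilderSketch().predicateToIntSQL("""("name" = 'Elf')""")
//   => IIF(("name" = 'Elf'), 1, 0)
```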
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 7d476d43e5c7..7d339a90db8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -59,6 +59,8 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
supportedFunctions.contains(funcName)
class MsSqlServerSQLBuilder extends JDBCSQLBuilder {
+    override protected def predicateToIntSQL(input: String): String =
+      "IIF(" + input + ", 1, 0)"
    override def visitSortOrder(
        sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
      (sortDirection, nullOrdering) match {
@@ -87,12 +89,24 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
    expr match {
      case e: Predicate => e.name() match {
        case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
-          val Array(l, r) = e.children().map {
-            case p: Predicate => s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 END"
-            case o => inputToSQL(o)
-          }
+          val Array(l, r) = e.children().map(inputToSQLNoBool)
          visitBinaryComparison(e.name(), l, r)
-        case "CASE_WHEN" => visitCaseWhen(expressionsToStringArray(e.children())) + " = 1"
+        case "CASE_WHEN" =>
+          // Since MsSqlServer cannot handle boolean expressions inside a
+          // CASE WHEN, it is necessary to convert those into another
+          // expression that returns 1 or 0 depending on the result.
+          // Example:
+          // In:  ... CASE WHEN a = b THEN c = d ... END
+          // Out: ... CASE WHEN a = b THEN IIF(c = d, 1, 0) ... END = 1
+          val stringArray = e.children().grouped(2).flatMap {
+            case Array(whenExpression, thenExpression) =>
+              Array(inputToSQL(whenExpression), inputToSQLNoBool(thenExpression))
+            case Array(elseExpression) =>
+              Array(inputToSQLNoBool(elseExpression))
+          }.toArray
+
+          visitCaseWhen(stringArray) + " = 1"
        case _ => super.build(expr)
      }
      case _ => super.build(expr)
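For reference, a small standalone sketch (plain strings standing in for compiled child expressions) of the `grouped(2)` pairing used above: the CASE_WHEN predicate's children arrive flattened as `when1, then1, when2, then2, ..., else?`, so grouping in twos pairs each WHEN condition with its THEN value, and a trailing singleton is the optional ELSE.
```
val children = Array("w1", "t1", "w2", "t2", "e")
val parts = children.grouped(2).flatMap {
  // A pair: keep the WHEN condition boolean, wrap the THEN value.
  case Array(whenSql, thenSql) => Array(whenSql, s"IIF($thenSql, 1, 0)")
  // A lone trailing element is the ELSE value: wrap it too.
  case Array(elseSql) => Array(s"IIF($elseSql, 1, 0)")
}.toArray
// parts: Array(w1, IIF(t1, 1, 0), w2, IIF(t2, 1, 0), IIF(e, 1, 0))
```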