This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e1c3dc8004b [SPARK-50087] Robust handling of boolean expressions in 
CASE WHEN for MsSqlServer and future connectors
2e1c3dc8004b is described below

commit 2e1c3dc8004b4f003cde8dfae6857f5bef4bb170
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Nov 21 20:56:59 2024 +0800

    [SPARK-50087] Robust handling of boolean expressions in CASE WHEN for 
MsSqlServer and future connectors
    
    ### What changes were proposed in this pull request?
    This PR proposes to propagate the `isPredicate` info in `V2ExpressionBuilder`
    and, for MsSqlServer, to wrap the THEN and ELSE children of a CASE WHEN
    expression (only those that are `Predicate`s; the WHEN conditions are left
    as-is) with `IIF(<predicate>, 1, 0)`. This forces those branches to return
    an int instead of a boolean, because SqlServer cannot handle boolean
    expressions as a return value in CASE WHEN.
    
    E.g.
    ```CASE WHEN ... ELSE a = b END```
    
    Old behavior (rejected by SqlServer, because the ELSE branch is a boolean expression):
    ```CASE WHEN ... ELSE a = b END = 1```
    
    New behavior:
    Because the SqlServer dialect appends `= 1` to the CASE WHEN expression, its
    THEN and ELSE branches must return an int. Therefore the final expression becomes:
    ```CASE WHEN ... ELSE IIF(a = b, 1, 0) END = 1```
    
    ### Why are the changes needed?
    Users cannot query MsSqlServer data with CASE WHEN or IF clauses whose
    branches return a boolean value.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added tests to MsSqlServerIntegrationSuite
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48621 from andrej-db/SPARK-50087-CaseWhen.
    
    Lead-authored-by: Wenchen Fan <[email protected]>
    Co-authored-by: andrej-db <[email protected]>
    Co-authored-by: Andrej Gobeljić <[email protected]>
    Co-authored-by: andrej-gobeljic_data <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/jdbc/v2/MsSqlServerIntegrationSuite.scala  | 79 ++++++++++++++++++++++
 .../sql/catalyst/util/V2ExpressionBuilder.scala    |  6 +-
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   | 13 ++++
 .../apache/spark/sql/jdbc/MsSqlServerDialect.scala | 24 +++++--
 4 files changed, 114 insertions(+), 8 deletions(-)

diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index d884ad4c6246..fd7efb1efb76 100644
--- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -20,7 +20,11 @@ package org.apache.spark.sql.jdbc.v2
 import java.sql.Connection
 
 import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
 import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker
 import org.apache.spark.sql.types._
@@ -37,6 +41,17 @@ import org.apache.spark.tags.DockerTest
 @DockerTest
 class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with 
V2JDBCTest {
 
+  def getExternalEngineQuery(executedPlan: SparkPlan): String = {
+    
getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery
+  }
+
+  def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = {
+    val queryNode = executedPlan.collect { case r: RowDataSourceScanExec =>
+      r
+    }.head
+    queryNode.rdd
+  }
+
   override def excluded: Seq[String] = Seq(
     "simple scan with OFFSET",
     "simple scan with LIMIT and OFFSET",
@@ -146,4 +161,68 @@ class MsSqlServerIntegrationSuite extends 
DockerJDBCIntegrationV2Suite with V2JD
         |""".stripMargin)
     assert(df.collect().length == 2)
   }
+
+  test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") {
+    val df = sql(
+      s"""|SELECT * FROM $catalogName.employee
+          |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 
'Wizard') END
+          |""".stripMargin
+    )
+
+    // scalastyle:off
+    assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+      """SELECT  "dept","name","salary","bonus" FROM "employee"  WHERE (CASE 
WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 
'Wizard'), 1, 0) END = 1)  """
+    )
+    // scalastyle:on
+    df.collect()
+  }
+
+  test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true 
test") {
+    val df = sql(
+      s"""|SELECT * FROM $catalogName.employee
+          |WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) 
END
+          |""".stripMargin
+    )
+
+    // scalastyle:off
+    assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+      """SELECT  "dept","name","salary","bonus" FROM "employee"  WHERE (CASE 
WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1)  """
+    )
+    // scalastyle:on
+    df.collect()
+  }
+
+  test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") {
+    val df = sql(
+      s"""|SELECT * FROM $catalogName.employee
+          |WHERE CASE WHEN (name = 'Legolas') THEN
+          | CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name = 
'Gandalf') END
+          | ELSE (name = 'Sauron') END
+          |""".stripMargin
+    )
+
+    // scalastyle:off
+    assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+      """SELECT  "dept","name","salary","bonus" FROM "employee"  WHERE (CASE 
WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" 
= 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE 
IIF(("name" = 'Sauron'), 1, 0) END = 1)  """
+    )
+    // scalastyle:on
+    df.collect()
+  }
+
+  test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") {
+    val df = sql(
+      s"""|SELECT * FROM $catalogName.employee
+          |WHERE CASE WHEN (name = 'Legolas') THEN
+          | CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END
+          | ELSE 'Sauron' END = name
+          |""".stripMargin
+    )
+
+    // scalastyle:off
+    assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+      """SELECT  "dept","name","salary","bonus" FROM "employee"  WHERE ("name" 
IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' 
THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name")  """
+    )
+    // scalastyle:on
+    df.collect()
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 61a26d7a4fbd..b0ce2bb4293e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -221,8 +221,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) extends L
     case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
     case caseWhen @ CaseWhen(branches, elseValue) =>
       val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
-      val values = branches.map(_._2).flatMap(generateExpression(_))
-      val elseExprOpt = elseValue.flatMap(generateExpression(_))
+      val values = branches.map(_._2).flatMap(generateExpression(_, 
isPredicate))
+      val elseExprOpt = elseValue.flatMap(generateExpression(_, isPredicate))
       if (conditions.length == branches.length && values.length == 
branches.length &&
           elseExprOpt.size == elseValue.size) {
         val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
@@ -421,7 +421,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: 
Boolean = false) extends L
       children: Seq[Expression],
       dataType: DataType,
       isPredicate: Boolean): Option[V2Expression] = {
-    val childrenExpressions = children.flatMap(generateExpression(_))
+    val childrenExpressions = children.flatMap(generateExpression(_, 
isPredicate))
     if (childrenExpressions.length == children.length) {
       if (isPredicate && dataType.isInstanceOf[BooleanType]) {
         Some(new V2Predicate(v2ExpressionName, 
childrenExpressions.toArray[V2Expression]))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 3bf1390cb664..81ad1a6d38bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -42,6 +42,7 @@ import 
org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.{Expression, Literal, 
NamedReference}
 import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
+import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, 
JDBCOptions, JdbcOptionsInWrite, JdbcUtils}
@@ -377,6 +378,18 @@ abstract class JdbcDialect extends Serializable with 
Logging {
   }
 
   private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
+    // Some dialects do not support boolean type and this convenient util 
function is
+    // provided to generate SQL string without boolean values.
+    protected def inputToSQLNoBool(input: Expression): String = input match {
+      case p: Predicate if p.name() == "ALWAYS_TRUE" => "1"
+      case p: Predicate if p.name() == "ALWAYS_FALSE" => "0"
+      case p: Predicate => predicateToIntSQL(inputToSQL(p))
+      case _ => inputToSQL(input)
+    }
+
+    protected def predicateToIntSQL(input: String): String =
+      "CASE WHEN " + input + " THEN 1 ELSE 0 END"
+
     override def visitLiteral(literal: Literal[_]): String = {
       Option(literal.value()).map(v =>
         compileValue(CatalystTypeConverters.convertToScala(v, 
literal.dataType())).toString)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 7d476d43e5c7..7d339a90db8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -59,6 +59,8 @@ private case class MsSqlServerDialect() extends JdbcDialect 
with NoLegacyJDBCErr
     supportedFunctions.contains(funcName)
 
   class MsSqlServerSQLBuilder extends JDBCSQLBuilder {
+    override protected def predicateToIntSQL(input: String): String =
+      "IIF(" + input + ", 1, 0)"
     override def visitSortOrder(
         sortKey: String, sortDirection: SortDirection, nullOrdering: 
NullOrdering): String = {
       (sortDirection, nullOrdering) match {
@@ -87,12 +89,24 @@ private case class MsSqlServerDialect() extends JdbcDialect 
with NoLegacyJDBCErr
       expr match {
         case e: Predicate => e.name() match {
           case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
-            val Array(l, r) = e.children().map {
-              case p: Predicate => s"CASE WHEN ${inputToSQL(p)} THEN 1 ELSE 0 
END"
-              case o => inputToSQL(o)
-            }
+            val Array(l, r) = e.children().map(inputToSQLNoBool)
             visitBinaryComparison(e.name(), l, r)
-          case "CASE_WHEN" => 
visitCaseWhen(expressionsToStringArray(e.children())) + " = 1"
+          case "CASE_WHEN" =>
+            // Since MsSqlServer cannot handle boolean expressions inside
+            // a CASE WHEN, it is necessary to convert those to another
+            // CASE WHEN expression that will return 1 or 0 depending on
+            // the result.
+            // Example:
+            // In:  ... CASE WHEN a = b THEN c = d ... END
+            // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END 
... END = 1
+            val stringArray = e.children().grouped(2).flatMap {
+              case Array(whenExpression, thenExpression) =>
+                Array(inputToSQL(whenExpression), 
inputToSQLNoBool(thenExpression))
+              case Array(elseExpression) =>
+                Array(inputToSQLNoBool(elseExpression))
+            }.toArray
+
+            visitCaseWhen(stringArray) + " = 1"
           case _ => super.build(expr)
         }
         case _ => super.build(expr)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to