This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 6d5414f2672 [SPARK-42851][SQL] Guard EquivalentExpressions.addExpr() with supportedExpression()
6d5414f2672 is described below

commit 6d5414f2672d7fd9a0c8ffe36feef6b3dfb60c74
Author: Kris Mok <[email protected]>
AuthorDate: Tue Mar 21 21:27:49 2023 +0800

    [SPARK-42851][SQL] Guard EquivalentExpressions.addExpr() with supportedExpression()
    
    ### What changes were proposed in this pull request?
    
    In `EquivalentExpressions.addExpr()`, add a `supportedExpression()` guard to make it consistent with `addExprTree()` and `getExprState()`.
    
    ### Why are the changes needed?
    
    This fixes a regression caused by https://github.com/apache/spark/pull/39010, which added the `supportedExpression()` guard to `addExprTree()` and `getExprState()` but not to `addExpr()`.
    
    One example of a use case affected by the inconsistency is the `PhysicalAggregation` pattern in physical planning. That pattern calls `addExpr()` to deduplicate the aggregate expressions, and then calls `getExprState()` to deduplicate the result expressions. Guarding inconsistently causes the aggregate and result expressions to go out of sync, eventually resulting in a query execution error (or a whole-stage codegen error).
    
    ### Does this PR introduce _any_ user-facing change?
    
    This fixes a regression affecting Spark 3.3.2+, where it may manifest as an error when running aggregate operators with higher-order functions.
    
    Example running the SQL command:
    ```sql
    select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) 
from range(2)
    ```
    Example error message before the fix:
    ```
    java.lang.IllegalStateException: Couldn't find max(transform(array(id#0L), 
lambdafunction(lambda x#2L, lambda x#2L, false)))#4 in 
[max(transform(array(id#0L), lambdafunction(lambda x#1L, lambda x#1L, 
false)))#3]
    ```
    After the fix, this error is gone.
    
    ### How was this patch tested?
    
    Added new test cases to `SubexpressionEliminationSuite` for the immediate issue, and to `DataFrameAggregateSuite` for an example of the user-visible symptom.
    
    Closes #40473 from rednaxelafx/spark-42851.
    
    Authored-by: Kris Mok <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit ef0a76eeea30fabb04499908b04124464225f5fd)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/expressions/EquivalentExpressions.scala   |  6 +++++-
 .../expressions/SubexpressionEliminationSuite.scala    | 18 +++++++++++++++++-
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala |  7 +++++++
 3 files changed, 29 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 3ffd9f9d887..f47391c0492 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -40,7 +40,11 @@ class EquivalentExpressions {
    * Returns true if there was already a matching expression.
    */
   def addExpr(expr: Expression): Boolean = {
-    updateExprInMap(expr, equivalenceMap)
+    if (supportedExpression(expr)) {
+      updateExprInMap(expr, equivalenceMap)
+    } else {
+      false
+    }
   }
 
   /**
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index b16629f59aa..44d8ea3a112 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
+import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, 
ObjectType}
 
 class SubexpressionEliminationSuite extends SparkFunSuite with 
ExpressionEvalHelper {
   test("Semantic equals and hash") {
@@ -449,6 +449,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite 
with ExpressionEvalHel
     assert(e2.getCommonSubexpressions.size == 1)
     assert(e2.getCommonSubexpressions.head == add)
   }
+
+  test("SPARK-42851: Handle supportExpression consistently across add and 
get") {
+    val expr = {
+      val function = (lambda: Expression) => Add(lambda, Literal(1))
+      val elementType = IntegerType
+      val colClass = classOf[Array[Int]]
+      val inputType = ObjectType(colClass)
+      val inputObject = BoundReference(0, inputType, nullable = true)
+      objects.MapObjects(function, inputObject, elementType, true, 
Option(colClass))
+    }
+    val equivalence = new EquivalentExpressions
+    equivalence.addExpr(expr)
+    val hasMatching = equivalence.addExpr(expr)
+    val cseState = equivalence.getExprState(expr)
+    assert(hasMatching == cseState.isDefined)
+  }
 }
 
 case class CodegenFallbackExpression(child: Expression)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index cc4ac37904a..ea5e47ede55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1540,6 +1540,13 @@ class DataFrameAggregateSuite extends QueryTest
     )
     checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil)
   }
+
+  test("SPARK-42851: common subexpression should consistently handle aggregate 
and result exprs") {
+    val res = sql(
+      "select max(transform(array(id), x -> x)), max(transform(array(id), x -> 
x)) from range(2)"
+    )
+    checkAnswer(res, Row(Array(1), Array(1)))
+  }
 }
 
 case class B(c: Option[Double])


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to