This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new cf3d10156ff [SPARK-44930][SQL] Deterministic ApplyFunctionExpression 
should be foldable
cf3d10156ff is described below

commit cf3d10156ff86e1b0c27cdc5706b345e84867cf9
Author: xianyangliu <[email protected]>
AuthorDate: Fri Aug 25 15:01:35 2023 +0800

    [SPARK-44930][SQL] Deterministic ApplyFunctionExpression should be foldable
    
    ### What changes were proposed in this pull request?
    
    Currently, ApplyFunctionExpression is never foldable because it inherits the 
default value from Expression.  However, a deterministic 
ApplyFunctionExpression whose children are all foldable should itself be foldable.
    
    ### Why are the changes needed?
    
    This lets the optimizer constant-fold a V2 UDF that is applied to constant 
expressions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New UT.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42629 from ConeyLiu/constant-fold-v2-udf.
    
    Authored-by: xianyangliu <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 994389f42a40d292a72482e3d76d29bada82d8ec)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/ApplyFunctionExpression.scala      |  1 +
 .../sql/connector/DataSourceV2FunctionSuite.scala  | 22 ++++++++++++----------
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
index da4000f53e3..a1815cf3b3d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
@@ -33,6 +33,7 @@ case class ApplyFunctionExpression(
   override def inputTypes: Seq[AbstractDataType] = function.inputTypes().toSeq
   override lazy val deterministic: Boolean = function.isDeterministic &&
       children.forall(_.deterministic)
+  override def foldable: Boolean = deterministic && children.forall(_.foldable)
 
   private lazy val reusedRow = new SpecificInternalRow(function.inputTypes())
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index 32391eac9a8..b74d7318a92 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -24,7 +24,6 @@ import 
test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd._
 import test.org.apache.spark.sql.connector.catalog.functions.JavaRandomAdd._
 import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import 
org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, 
NO_CODEGEN}
@@ -322,14 +321,7 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
   test("scalar function: bad magic method") {
     
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
 emptyProps)
     addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenBadMagic))
-    // TODO assign a error-classes name
-    checkError(
-      exception = intercept[SparkException] {
-        sql("SELECT testcat.ns.strlen('abc')").collect()
-      },
-      errorClass = null,
-      parameters = Map.empty
-    )
+    intercept[UnsupportedOperationException](sql("SELECT 
testcat.ns.strlen('abc')").collect())
   }
 
   test("scalar function: bad magic method with default impl") {
@@ -341,7 +333,7 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase 
{
   test("scalar function: no implementation found") {
     
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
 emptyProps)
     addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenNoImpl))
-    intercept[SparkException](sql("SELECT testcat.ns.strlen('abc')").collect())
+    intercept[UnsupportedOperationException](sql("SELECT 
testcat.ns.strlen('abc')").collect())
   }
 
   test("scalar function: invalid parameter type or length") {
@@ -688,6 +680,16 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
     }
   }
 
+  test("SPARK-44930: Fold deterministic ApplyFunctionExpression") {
+    
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
 emptyProps)
+    addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(StrLenDefault))
+
+    val df1 = sql("SELECT testcat.ns.strlen('abc') as col1")
+    val df2 = sql("SELECT 3 as col1")
+    comparePlans(df1.queryExecution.optimizedPlan, 
df2.queryExecution.optimizedPlan)
+    checkAnswer(df1, Row(3) :: Nil)
+  }
+
   private case object StrLenDefault extends ScalarFunction[Int] {
     override def inputTypes(): Array[DataType] = Array(StringType)
     override def resultType(): DataType = IntegerType


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to