This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fd8230b84975 [SPARK-49229][CONNECT] Deduplicate Scala UDF handling in 
the SparkConnectPlanner
fd8230b84975 is described below

commit fd8230b84975f47b4ccf4308856078831d9df365
Author: Paddy Xu <[email protected]>
AuthorDate: Tue Dec 31 09:14:18 2024 +0900

    [SPARK-49229][CONNECT] Deduplicate Scala UDF handling in the 
SparkConnectPlanner
    
    ### What changes were proposed in this pull request?
    
    This PR removes some duplicated code from the `transformScalaFunction` and 
`transformScalaUDF` methods of `SparkConnectPlanner`.
    
    ### Why are the changes needed?
    
    Keep the code tidy.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49324 from xupefei/udf-handling-deduplicate.
    
    Authored-by: Paddy Xu <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 44 +++++++++++-----------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 8bb5e54c36cc..628b758dd4e2 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -77,7 +77,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
 import 
org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
 import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
 import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, 
SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl, 
TypedAggUtils}
+import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl, 
TypedAggUtils, UserDefinedFunctionUtils}
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, 
StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -1727,34 +1727,36 @@ class SparkConnectPlanner(
   }
 
   /**
-   * Translates a Scala user-defined function from proto to the Catalyst 
expression.
+   * Translates a Scala user-defined function or aggregator from proto to the 
corresponding
+   * Catalyst expression.
    *
    * @param fun
-   *   Proto representation of the Scala user-defined function.
+   *   Proto representation of the Scala user-defined function or aggregator.
    * @return
-   *   ScalaUDF.
+   *   An expression, either a ScalaUDF or a ScalaAggregator.
    */
   private def transformScalaUDF(fun: proto.CommonInlineUserDefinedFunction): 
Expression = {
-    val udf = fun.getScalarScalaUdf
-    val udfPacket = unpackUdf(fun)
-    if (udf.getAggregate) {
-      ScalaAggregator(
-        transformScalaFunction(fun).asInstanceOf[UserDefinedAggregator[Any, 
Any, Any]],
-        fun.getArgumentsList.asScala.map(transformExpression).toSeq)
-        .toAggregateExpression()
-    } else {
-      ScalaUDF(
-        function = udfPacket.function,
-        dataType = transformDataType(udf.getOutputType),
-        children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
-        inputEncoders = udfPacket.inputEncoders.map(e => 
Try(ExpressionEncoder(e)).toOption),
-        outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
-        udfName = Option(fun.getFunctionName),
-        nullable = udf.getNullable,
-        udfDeterministic = fun.getDeterministic)
+    val children = fun.getArgumentsList.asScala.map(transformExpression).toSeq
+    transformScalaFunction(fun) match {
+      case udf: SparkUserDefinedFunction =>
+        UserDefinedFunctionUtils.toScalaUDF(udf, children)
+      case uda: UserDefinedAggregator[_, _, _] =>
+        ScalaAggregator(uda, children).toAggregateExpression()
+      case other =>
+        throw InvalidPlanInput(
+          s"Unsupported UserDefinedFunction implementation: ${other.getClass}")
     }
   }
 
+  /**
+   * Translates a Scala user-defined function or aggregator from proto to a 
UserDefinedFunction.
+   *
+   * @param fun
+   *   Proto representation of the Scala user-defined function or aggregator.
+   * @return
+   *   A concrete UserDefinedFunction implementation, either a 
SparkUserDefinedFunction or a
+   *   UserDefinedAggregator.
+   */
   private def transformScalaFunction(
       fun: proto.CommonInlineUserDefinedFunction): UserDefinedFunction = {
     val udf = fun.getScalarScalaUdf


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to