This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fd8230b84975 [SPARK-49229][CONNECT] Deduplicate Scala UDF handling in
the SparkConnectPlanner
fd8230b84975 is described below
commit fd8230b84975f47b4ccf4308856078831d9df365
Author: Paddy Xu <[email protected]>
AuthorDate: Tue Dec 31 09:14:18 2024 +0900
[SPARK-49229][CONNECT] Deduplicate Scala UDF handling in the
SparkConnectPlanner
### What changes were proposed in this pull request?
This PR removes duplicated code from the `transformScalaFunction` and
`transformScalaUDF` methods of `SparkConnectPlanner`.
### Why are the changes needed?
Keep the code tidy.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49324 from xupefei/udf-handling-deduplicate.
Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 44 +++++++++++-----------
1 file changed, 23 insertions(+), 21 deletions(-)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 8bb5e54c36cc..628b758dd4e2 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -77,7 +77,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions
import
org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator,
SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl,
TypedAggUtils}
+import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl,
TypedAggUtils, UserDefinedFunctionUtils}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode,
StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -1727,34 +1727,36 @@ class SparkConnectPlanner(
}
/**
- * Translates a Scala user-defined function from proto to the Catalyst
expression.
+ * Translates a Scala user-defined function or aggregator from proto to the
corresponding
+ * Catalyst expression.
*
* @param fun
- * Proto representation of the Scala user-defined function.
+ * Proto representation of the Scala user-defined function or aggregator.
* @return
- * ScalaUDF.
+ * An expression, either a ScalaUDF or a ScalaAggregator.
*/
private def transformScalaUDF(fun: proto.CommonInlineUserDefinedFunction):
Expression = {
- val udf = fun.getScalarScalaUdf
- val udfPacket = unpackUdf(fun)
- if (udf.getAggregate) {
- ScalaAggregator(
- transformScalaFunction(fun).asInstanceOf[UserDefinedAggregator[Any,
Any, Any]],
- fun.getArgumentsList.asScala.map(transformExpression).toSeq)
- .toAggregateExpression()
- } else {
- ScalaUDF(
- function = udfPacket.function,
- dataType = transformDataType(udf.getOutputType),
- children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
- inputEncoders = udfPacket.inputEncoders.map(e =>
Try(ExpressionEncoder(e)).toOption),
- outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
- udfName = Option(fun.getFunctionName),
- nullable = udf.getNullable,
- udfDeterministic = fun.getDeterministic)
+ val children = fun.getArgumentsList.asScala.map(transformExpression).toSeq
+ transformScalaFunction(fun) match {
+ case udf: SparkUserDefinedFunction =>
+ UserDefinedFunctionUtils.toScalaUDF(udf, children)
+ case uda: UserDefinedAggregator[_, _, _] =>
+ ScalaAggregator(uda, children).toAggregateExpression()
+ case other =>
+ throw InvalidPlanInput(
+ s"Unsupported UserDefinedFunction implementation: ${other.getClass}")
}
}
+ /**
+ * Translates a Scala user-defined function or aggregator from proto to a
UserDefinedFunction.
+ *
+ * @param fun
+ * Proto representation of the Scala user-defined function or aggregator.
+ * @return
+ * A concrete UserDefinedFunction implementation, either a
SparkUserDefinedFunction or a
+ * UserDefinedAggregator.
+ */
private def transformScalaFunction(
fun: proto.CommonInlineUserDefinedFunction): UserDefinedFunction = {
val udf = fun.getScalarScalaUdf
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]