This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e363e9d1b7b [SPARK-42124][PYTHON][CONNECT] Scalar Inline Python UDF in Spark Connect
e363e9d1b7b is described below

commit e363e9d1b7b2ff19c7ff39760521f83481f78c1c
Author: Xinrong Meng <[email protected]>
AuthorDate: Wed Jan 25 19:42:46 2023 +0900

    [SPARK-42124][PYTHON][CONNECT] Scalar Inline Python UDF in Spark Connect
    
    ### What changes were proposed in this pull request?
    Support scalar inline Python user-defined functions (a.k.a. unregistered Python UDFs) in Spark Connect.
    
    Currently, the user-specified return type must be an instance of `pyspark.sql.types.DataType`.
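    
    As a minimal sketch of what this accepts (`plus_two` is a hypothetical name; DDL-formatted strings such as `"long"` are deferred to SPARK-42126 below):
    
    ```
    from pyspark.sql.functions import udf
    from pyspark.sql.types import LongType
    
    # Accepted: the return type is a DataType instance.
    plus_two = udf(lambda x: x + 2, returnType=LongType())
    
    # Not yet accepted in Spark Connect: a DDL-formatted string.
    # plus_two = udf(lambda x: x + 2, returnType="long")
    ```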
    
    There will be follow-up PRs on:
    - Support Pandas UDF [jira](https://issues.apache.org/jira/browse/SPARK-42125)
    - Support user-specified return type in DDL-formatted strings [jira](https://issues.apache.org/jira/browse/SPARK-42126)
    
    ### Why are the changes needed?
    Feature parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Unregistered Python UDFs are now supported, as shown below:
    
    ```
    >>> spark.range(2).withColumn('plus_one', udf(lambda x: x + 1)('id')).show()
    +---+--------+
    | id|plus_one|
    +---+--------+
    |  0|       1|
    |  1|       2|
    +---+--------+
    
    >>> @udf(LongType())
    ... def f(x):
    ...   return x + 1
    ...
    >>> spark.range(2).withColumn('plus_one', f('id')).show()
    +---+--------+
    | id|plus_one|
    +---+--------+
    |  0|       1|
    |  1|       2|
    +---+--------+
    ```
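    
    The wrapper returned by `udf` also mirrors the vanilla PySpark UDF API, e.g. `asNondeterministic()`; a minimal sketch (`rand_udf` is a hypothetical name, and its output varies per run):
    
    ```
    >>> import random
    >>> from pyspark.sql.types import IntegerType
    >>> rand_udf = udf(lambda: random.randint(0, 9), IntegerType()).asNondeterministic()
    >>> spark.range(2).withColumn('r', rand_udf()).show()
    ```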
    
    ### How was this patch tested?
    Unit tests: `test_connect_function.py` for the Python client and `ConnectProtoMessagesSuite.scala` for the new proto messages.
    
    Closes #39585 from xinrong-meng/connect_udf.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/protobuf/spark/connect/expressions.proto  |  24 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  65 ++++++++
 .../messages/ConnectProtoMessagesSuite.scala       |  34 +++++
 python/pyspark/sql/connect/_typing.py              |  21 ++-
 python/pyspark/sql/connect/expressions.py          |  62 ++++++++
 python/pyspark/sql/connect/functions.py            |  30 +++-
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 118 +++++++++------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  88 +++++++++++
 python/pyspark/sql/connect/udf.py                  | 165 +++++++++++++++++++++
 python/pyspark/sql/functions.py                    |   4 +
 .../sql/tests/connect/test_connect_function.py     |  47 +++++-
 11 files changed, 604 insertions(+), 54 deletions(-)
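
At a high level, the client pickles the Python callable together with its return type into the new `PythonUDF` message, wraps it in a `ScalarInlineUserDefinedFunction` expression, and the server-side `SparkConnectPlanner` decodes it back into a Catalyst `PythonUDF`. A minimal sketch of the client-side payload, condensed from `udf.py`/`expressions.py` below (local names are illustrative only):

```
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import LongType

# Stand-ins for a UDF created as udf(lambda x: x + 1, LongType())
func = lambda x: x + 1
return_type = LongType()

# The three PythonUDF fields assembled in UserDefinedFunction.__call__:
command = CloudPickleSerializer().dumps((func, return_type))  # pickled bytes payload
output_type = return_type.json()  # '"long"'; parsed server-side via the DataType.fromJson fallback
eval_type = 100                   # PythonEvalType.SQL_BATCHED_UDF
```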

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index f7feae0e2f0..7ae0a6c5008 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -44,6 +44,7 @@ message Expression {
     UnresolvedExtractValue unresolved_extract_value = 12;
     UpdateFields update_fields = 13;
     UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
+    ScalarInlineUserDefinedFunction scalar_inline_user_defined_function = 15;
 
    // This field is used to mark extensions to the protocol. When plugins generate arbitrary
    // relations they can add them here. During the planning the correct resolution is done.
@@ -295,3 +296,26 @@ message Expression {
     repeated string name_parts = 1;
   }
 }
+
+message ScalarInlineUserDefinedFunction {
+  // (Required) Name of the user-defined function.
+  string function_name = 1;
+  // (Required) Indicate if the user-defined function is deterministic.
+  bool deterministic = 2;
+  // (Optional) Function arguments. Empty arguments are allowed.
+  repeated Expression arguments = 3;
+  // (Required) Indicate the function type of the user-defined function.
+  oneof function {
+    PythonUDF python_udf = 4;
+  }
+}
+
+message PythonUDF {
+  // (Required) Output type of the Python UDF
+  string output_type = 1;
+  // (Required) EvalType of the Python UDF
+  int32 eval_type = 2;
+  // (Required) The encoded commands of the Python UDF
+  bytes command = 3;
+}
+
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f95f065c5b3..dc921cee282 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -742,6 +742,8 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformWindowExpression(exp.getWindow)
       case proto.Expression.ExprTypeCase.EXTENSION =>
         transformExpressionPlugin(exp.getExtension)
+      case proto.Expression.ExprTypeCase.SCALAR_INLINE_USER_DEFINED_FUNCTION =>
+        transformScalarInlineUserDefinedFunction(exp.getScalarInlineUserDefinedFunction)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not 
supported")
@@ -816,6 +818,65 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  /**
+   * Translates a user-defined function from proto to the Catalyst expression.
+   *
+   * @param fun
+   *   Proto representation of the function call.
+   * @return
+   *   Expression.
+   */
+  private def transformScalarInlineUserDefinedFunction(
+      fun: proto.ScalarInlineUserDefinedFunction): Expression = {
+    fun.getFunctionCase match {
+      case proto.ScalarInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        transformPythonUDF(fun)
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${fun.getFunctionCase.getNumber} is not 
supported")
+    }
+  }
+
+  /**
+   * Translates a Python user-defined function from proto to the Catalyst expression.
+   *
+   * @param fun
+   *   Proto representation of the Python user-defined function.
+   * @return
+   *   PythonUDF.
+   */
+  private def transformPythonUDF(fun: proto.ScalarInlineUserDefinedFunction): PythonUDF = {
+    val udf = fun.getPythonUdf
+    PythonUDF(
+      name = fun.getFunctionName,
+      func = transformPythonFunction(udf),
+      dataType = DataType.parseTypeWithFallback(
+        schema = udf.getOutputType,
+        parser = DataType.fromDDL,
+        fallbackParser = DataType.fromJson) match {
+        case s: DataType => s
+        case other => throw InvalidPlanInput(s"Invalid return type $other")
+      },
+      children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+      evalType = udf.getEvalType,
+      udfDeterministic = fun.getDeterministic)
+  }
+
+  private def transformPythonFunction(fun: proto.PythonUDF): SimplePythonFunction = {
+    SimplePythonFunction(
+      command = fun.getCommand.toByteArray,
+      // Empty environment variables
+      envVars = Maps.newHashMap(),
+      // No imported Python libraries
+      pythonIncludes = Lists.newArrayList(),
+      pythonExec = pythonExec,
+      pythonVer = "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
+      // Empty broadcast variables
+      broadcastVars = Lists.newArrayList(),
+      // Null accumulator
+      accumulator = null)
+  }
+
   /**
    * Translates a LambdaFunction from proto to the Catalyst expression.
    */
@@ -1351,11 +1412,15 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = {
     val function = SimplePythonFunction(
       cf.getSerializedFunction.toByteArray,
+      // Empty environment variables
       Maps.newHashMap(),
+      // No imported Python libraries
       Lists.newArrayList(),
       pythonExec,
       "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
+      // Empty broadcast variables
       Lists.newArrayList(),
+      // Null accumulator
       null)
 
     val udf = UserDefinedPythonFunction(
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
index 08f12aa6d08..3d8fae83428 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.connect.messages
 
+import com.google.protobuf.ByteString
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
 
@@ -48,4 +50,36 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
     assert(extLit.getLiteral.hasInteger)
     assert(extLit.getLiteral.getInteger == 32)
   }
+
+  test("ScalarInlineUserDefinedFunction") {
+    val arguments = proto.Expression
+      .newBuilder()
+      .setUnresolvedAttribute(
+        proto.Expression.UnresolvedAttribute.newBuilder().setUnparsedIdentifier("id"))
+      .build()
+
+    val pythonUdf = proto.PythonUDF
+      .newBuilder()
+      .setEvalType(100)
+      .setOutputType("\"integer\"")
+      .setCommand(ByteString.copyFrom("command".getBytes()))
+      .build()
+
+    val scalarInlineUserDefinedFunctionExpr = proto.Expression
+      .newBuilder()
+      .setScalarInlineUserDefinedFunction(
+        proto.ScalarInlineUserDefinedFunction
+          .newBuilder()
+          .setFunctionName("f")
+          .setDeterministic(true)
+          .addArguments(arguments)
+          .setPythonUdf(pythonUdf))
+      .build()
+
+    val fun = scalarInlineUserDefinedFunctionExpr.getScalarInlineUserDefinedFunction()
+    assert(fun.getFunctionName == "f")
+    assert(fun.getDeterministic == true)
+    assert(fun.getArgumentsCount == 1)
+    assert(fun.hasPythonUdf == true)
+  }
 }
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 29a14384c82..66b08d898fe 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -22,11 +22,12 @@ if sys.version_info >= (3, 8):
 else:
     from typing_extensions import Protocol
 
-from typing import Union, Optional
+from typing import Any, Callable, Union, Optional
 import datetime
 import decimal
 
 from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.types import DataType
 
 
 ColumnOrName = Union[Column, str]
@@ -41,6 +42,24 @@ DecimalLiteral = decimal.Decimal
 
 DateTimeLiteral = Union[datetime.datetime, datetime.date]
 
+DataTypeOrString = Union[DataType, str]
+
+
+class UserDefinedFunctionLike(Protocol):
+    func: Callable[..., Any]
+    evalType: int
+    deterministic: bool
+
+    @property
+    def returnType(self) -> DataType:
+        ...
+
+    def __call__(self, *args: ColumnOrName) -> Column:
+        ...
+
+    def asNondeterministic(self) -> "UserDefinedFunctionLike":
+        ...
+
 
 class UserDefinedFunctionCallable(Protocol):
     def __call__(self, *_: ColumnOrName) -> Column:
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index c8d361af2a5..0fa67a5f8d0 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -480,6 +480,68 @@ class UnresolvedFunction(Expression):
             return f"{self._name}({', '.join([str(arg) for arg in 
self._args])})"
 
 
+class PythonUDF:
+    """Represents a Python user-defined function."""
+
+    def __init__(
+        self,
+        output_type: str,
+        eval_type: int,
+        command: bytes,
+    ) -> None:
+        self._output_type = output_type
+        self._eval_type = eval_type
+        self._command = command
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDF:
+        expr = proto.PythonUDF()
+        expr.output_type = self._output_type
+        expr.eval_type = self._eval_type
+        expr.command = self._command
+        return expr
+
+    def __repr__(self) -> str:
+        return (
+            f"{self._output_type}, {self._eval_type}, "
+            f"{self._command}"  # type: ignore[str-bytes-safe]
+        )
+
+
+class ScalarInlineUserDefinedFunction(Expression):
+    """Represents a scalar inline user-defined function of any programming 
languages."""
+
+    def __init__(
+        self,
+        function_name: str,
+        deterministic: bool,
+        arguments: Sequence[Expression],
+        function: PythonUDF,
+    ):
+        self._function_name = function_name
+        self._deterministic = deterministic
+        self._arguments = arguments
+        self._function = function
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        expr = proto.Expression()
+        expr.scalar_inline_user_defined_function.function_name = self._function_name
+        expr.scalar_inline_user_defined_function.deterministic = self._deterministic
+        if len(self._arguments) > 0:
+            expr.scalar_inline_user_defined_function.arguments.extend(
+                [arg.to_plan(session) for arg in self._arguments]
+            )
+        expr.scalar_inline_user_defined_function.python_udf.CopyFrom(
+            self._function.to_plan(session)
+        )
+        return expr
+
+    def __repr__(self) -> str:
+        return (
+            f"{self._function_name}({', '.join([str(arg) for arg in 
self._arguments])}), "
+            f"{self._deterministic}, {self._function}"
+        )
+
+
 class WithField(Expression):
     def __init__(
         self,
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 75f6ba1ff64..ee7b45622b3 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -17,6 +17,7 @@
 
 import inspect
 import warnings
+import functools
 from typing import (
     Any,
     Dict,
@@ -49,11 +50,16 @@ from pyspark.sql.connect.expressions import (
     LambdaFunction,
     UnresolvedNamedLambdaVariable,
 )
+from pyspark.sql.connect.udf import _create_udf
 from pyspark.sql import functions as pysparkfuncs
-from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType
+from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect._typing import ColumnOrName
+    from pyspark.sql.connect._typing import (
+        ColumnOrName,
+        DataTypeOrString,
+        UserDefinedFunctionLike,
+    )
     from pyspark.sql.connect.dataframe import DataFrame
 
 
@@ -2401,8 +2407,24 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
 unwrap_udt.__doc__ = pysparkfuncs.unwrap_udt.__doc__
 
 
-def udf(*args: Any, **kwargs: Any) -> None:
-    raise NotImplementedError("udf() is not implemented.")
+def udf(
+    f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
+    returnType: "DataTypeOrString" = StringType(),
+) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
+    from pyspark.rdd import PythonEvalType
+
+    if f is None or isinstance(f, (str, DataType)):
+        # If DataType has been passed as a positional argument
+        # for decorator use it as a returnType
+        return_type = f or returnType
+        return functools.partial(
+            _create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF
+        )
+    else:
+        return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF)
+
+
+udf.__doc__ = pysparkfuncs.udf.__doc__
 
 
 def pandas_udf(*args: Any, **kwargs: Any) -> None:
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 87c16964102..0b2419fee35 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92$\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 
@@ -61,6 +61,10 @@ _EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"]
 _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE = _EXPRESSION.nested_types_by_name[
     "UnresolvedNamedLambdaVariable"
 ]
+_SCALARINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[
+    "ScalarInlineUserDefinedFunction"
+]
+_PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"]
_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
     "FrameType"
 ]
@@ -257,52 +261,78 @@ _sym_db.RegisterMessage(Expression.Alias)
 _sym_db.RegisterMessage(Expression.LambdaFunction)
 _sym_db.RegisterMessage(Expression.UnresolvedNamedLambdaVariable)
 
+ScalarInlineUserDefinedFunction = _reflection.GeneratedProtocolMessageType(
+    "ScalarInlineUserDefinedFunction",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _SCALARINLINEUSERDEFINEDFUNCTION,
+        "__module__": "spark.connect.expressions_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.ScalarInlineUserDefinedFunction)
+    },
+)
+_sym_db.RegisterMessage(ScalarInlineUserDefinedFunction)
+
+PythonUDF = _reflection.GeneratedProtocolMessageType(
+    "PythonUDF",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _PYTHONUDF,
+        "__module__": "spark.connect.expressions_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.PythonUDF)
+    },
+)
+_sym_db.RegisterMessage(PythonUDF)
+
 if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 4731
-    _EXPRESSION_WINDOW._serialized_start = 1347
-    _EXPRESSION_WINDOW._serialized_end = 2130
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1637
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2130
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1904
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2049
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2051
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2130
-    _EXPRESSION_SORTORDER._serialized_start = 2133
-    _EXPRESSION_SORTORDER._serialized_end = 2558
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2363
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2471
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2473
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2558
-    _EXPRESSION_CAST._serialized_start = 2561
-    _EXPRESSION_CAST._serialized_end = 2706
-    _EXPRESSION_LITERAL._serialized_start = 2709
-    _EXPRESSION_LITERAL._serialized_end = 3585
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3352
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3469
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3471
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3569
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3587
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3657
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3660
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3864
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3866
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3916
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3918
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4000
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4002
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4046
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4049
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4181
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 4184
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 4371
-    _EXPRESSION_ALIAS._serialized_start = 4373
-    _EXPRESSION_ALIAS._serialized_end = 4493
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4496
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4654
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4656
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4718
+    _EXPRESSION._serialized_end = 4859
+    _EXPRESSION_WINDOW._serialized_start = 1475
+    _EXPRESSION_WINDOW._serialized_end = 2258
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2258
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2032
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2177
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2179
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2258
+    _EXPRESSION_SORTORDER._serialized_start = 2261
+    _EXPRESSION_SORTORDER._serialized_end = 2686
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2491
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2599
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2601
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2686
+    _EXPRESSION_CAST._serialized_start = 2689
+    _EXPRESSION_CAST._serialized_end = 2834
+    _EXPRESSION_LITERAL._serialized_start = 2837
+    _EXPRESSION_LITERAL._serialized_end = 3713
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3480
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3597
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3599
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3697
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3715
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3785
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3788
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3992
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3994
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4044
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4046
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4128
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4130
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4174
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4177
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4309
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 4312
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 4499
+    _EXPRESSION_ALIAS._serialized_start = 4501
+    _EXPRESSION_ALIAS._serialized_end = 4621
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4624
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4782
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
+    _SCALARINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
+    _SCALARINLINEUSERDEFINEDFUNCTION._serialized_end = 5098
+    _PYTHONUDF._serialized_start = 5100
+    _PYTHONUDF._serialized_end = 5199
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 45889c1518f..0191a0cdaf4 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -932,6 +932,7 @@ class Expression(google.protobuf.message.Message):
     UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
     UPDATE_FIELDS_FIELD_NUMBER: builtins.int
     UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
+    SCALAR_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
@@ -964,6 +965,8 @@ class Expression(google.protobuf.message.Message):
         self,
     ) -> global___Expression.UnresolvedNamedLambdaVariable: ...
     @property
+    def scalar_inline_user_defined_function(self) -> global___ScalarInlineUserDefinedFunction: ...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
         relations they can add them here. During the planning the correct 
resolution is done.
@@ -986,6 +989,7 @@ class Expression(google.protobuf.message.Message):
         update_fields: global___Expression.UpdateFields | None = ...,
        unresolved_named_lambda_variable: global___Expression.UnresolvedNamedLambdaVariable
         | None = ...,
+        scalar_inline_user_defined_function: global___ScalarInlineUserDefinedFunction | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -1005,6 +1009,8 @@ class Expression(google.protobuf.message.Message):
             b"lambda_function",
             "literal",
             b"literal",
+            "scalar_inline_user_defined_function",
+            b"scalar_inline_user_defined_function",
             "sort_order",
             b"sort_order",
             "unresolved_attribute",
@@ -1042,6 +1048,8 @@ class Expression(google.protobuf.message.Message):
             b"lambda_function",
             "literal",
             b"literal",
+            "scalar_inline_user_defined_function",
+            b"scalar_inline_user_defined_function",
             "sort_order",
             b"sort_order",
             "unresolved_attribute",
@@ -1079,7 +1087,87 @@ class Expression(google.protobuf.message.Message):
         "unresolved_extract_value",
         "update_fields",
         "unresolved_named_lambda_variable",
+        "scalar_inline_user_defined_function",
         "extension",
     ] | None: ...
 
 global___Expression = Expression
+
+class ScalarInlineUserDefinedFunction(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    FUNCTION_NAME_FIELD_NUMBER: builtins.int
+    DETERMINISTIC_FIELD_NUMBER: builtins.int
+    ARGUMENTS_FIELD_NUMBER: builtins.int
+    PYTHON_UDF_FIELD_NUMBER: builtins.int
+    function_name: builtins.str
+    """(Required) Name of the user-defined function."""
+    deterministic: builtins.bool
+    """(Required) Indicate if the user-defined function is deterministic."""
+    @property
+    def arguments(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]:
+        """(Optional) Function arguments. Empty arguments are allowed."""
+    @property
+    def python_udf(self) -> global___PythonUDF: ...
+    def __init__(
+        self,
+        *,
+        function_name: builtins.str = ...,
+        deterministic: builtins.bool = ...,
+        arguments: collections.abc.Iterable[global___Expression] | None = ...,
+        python_udf: global___PythonUDF | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal["function", b"function", "python_udf", b"python_udf"],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "arguments",
+            b"arguments",
+            "deterministic",
+            b"deterministic",
+            "function",
+            b"function",
+            "function_name",
+            b"function_name",
+            "python_udf",
+            b"python_udf",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["function", b"function"]
+    ) -> typing_extensions.Literal["python_udf"] | None: ...
+
+global___ScalarInlineUserDefinedFunction = ScalarInlineUserDefinedFunction
+
+class PythonUDF(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    OUTPUT_TYPE_FIELD_NUMBER: builtins.int
+    EVAL_TYPE_FIELD_NUMBER: builtins.int
+    COMMAND_FIELD_NUMBER: builtins.int
+    output_type: builtins.str
+    """(Required) Output type of the Python UDF"""
+    eval_type: builtins.int
+    """(Required) EvalType of the Python UDF"""
+    command: builtins.bytes
+    """(Required) The encoded commands of the Python UDF"""
+    def __init__(
+        self,
+        *,
+        output_type: builtins.str = ...,
+        eval_type: builtins.int = ...,
+        command: builtins.bytes = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "command", b"command", "eval_type", b"eval_type", "output_type", 
b"output_type"
+        ],
+    ) -> None: ...
+
+global___PythonUDF = PythonUDF
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
new file mode 100644
index 00000000000..4a465084838
--- /dev/null
+++ b/python/pyspark/sql/connect/udf.py
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined function related classes and functions
+"""
+import functools
+from typing import Callable, Any, TYPE_CHECKING, Optional
+
+from pyspark.serializers import CloudPickleSerializer
+from pyspark.sql.connect.expressions import (
+    ColumnReference,
+    PythonUDF,
+    ScalarInlineUserDefinedFunction,
+)
+from pyspark.sql.connect.column import Column
+from pyspark.sql.types import DataType, StringType
+
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import (
+        ColumnOrName,
+        DataTypeOrString,
+        UserDefinedFunctionLike,
+    )
+    from pyspark.sql.types import StringType
+
+
+def _create_udf(
+    f: Callable[..., Any],
+    returnType: "DataTypeOrString",
+    evalType: int,
+    name: Optional[str] = None,
+    deterministic: bool = True,
+) -> "UserDefinedFunctionLike":
+    # Set the name of the UserDefinedFunction object to be the name of function f
+    udf_obj = UserDefinedFunction(
+        f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
+    )
+    return udf_obj._wrapped()
+
+
+class UserDefinedFunction:
+    """
+    User defined function in Python
+
+    Notes
+    -----
+    The constructor of this class is not supposed to be directly called.
+    Use :meth:`pyspark.sql.functions.udf` or :meth:`pyspark.sql.functions.pandas_udf`
+    to create this instance.
+    """
+
+    def __init__(
+        self,
+        func: Callable[..., Any],
+        returnType: "DataTypeOrString" = StringType(),
+        name: Optional[str] = None,
+        evalType: int = 100,
+        deterministic: bool = True,
+    ):
+        if not callable(func):
+            raise TypeError(
+                "Invalid function: not a function or callable (__call__ is not 
defined): "
+                "{0}".format(type(func))
+            )
+
+        if not isinstance(returnType, (DataType, str)):
+            raise TypeError(
+                "Invalid return type: returnType should be DataType or str "
+                "but is {}".format(returnType)
+            )
+
+        if not isinstance(evalType, int):
+            raise TypeError(
+                "Invalid evaluation type: evalType should be an int but is 
{}".format(evalType)
+            )
+
+        self.func = func
+        self._returnType = returnType
+        self._name = name or (
+            func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
+        )
+        self.evalType = evalType
+        self.deterministic = deterministic
+
+    def __call__(self, *cols: "ColumnOrName") -> Column:
+        arg_cols = [
+            col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols
+        ]
+        arg_exprs = [col._expr for col in arg_cols]
+        data_type_str = (
+            self._returnType.json() if isinstance(self._returnType, DataType) else self._returnType
+        )
+        py_udf = PythonUDF(
+            output_type=data_type_str,
+            eval_type=self.evalType,
+            command=CloudPickleSerializer().dumps((self.func, self._returnType)),
+        )
+        return Column(
+            ScalarInlineUserDefinedFunction(
+                function_name=self._name,
+                deterministic=self.deterministic,
+                arguments=arg_exprs,
+                function=py_udf,
+            )
+        )
+
+    # This function is for improving the online help system in the interactive interpreter.
+    # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
+    # argument annotation. (See: SPARK-19161)
+    def _wrapped(self) -> "UserDefinedFunctionLike":
+        """
+        Wrap this udf with a function and attach docstring from func
+        """
+
+        # It is possible for a callable instance without __name__ attribute or/and
+        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
+        # we should avoid wrapping the attributes from the wrapped function to the wrapper
+        # function. So, we take out these attribute names from the default names to set and
+        # then manually assign it after being wrapped.
+        assignments = tuple(
+            a for a in functools.WRAPPER_ASSIGNMENTS if a != "__name__" and a != "__module__"
+        )
+
+        @functools.wraps(self.func, assigned=assignments)
+        def wrapper(*args: "ColumnOrName") -> Column:
+            return self(*args)
+
+        wrapper.__name__ = self._name
+        wrapper.__module__ = (
+            self.func.__module__
+            if hasattr(self.func, "__module__")
+            else self.func.__class__.__module__
+        )
+
+        wrapper.func = self.func  # type: ignore[attr-defined]
+        wrapper.returnType = self._returnType  # type: ignore[attr-defined]
+        wrapper.evalType = self.evalType  # type: ignore[attr-defined]
+        wrapper.deterministic = self.deterministic  # type: ignore[attr-defined]
+        wrapper.asNondeterministic = functools.wraps(  # type: ignore[attr-defined]
+            self.asNondeterministic
+        )(lambda: self.asNondeterministic()._wrapped())
+        wrapper._unwrapped = self  # type: ignore[attr-defined]
+        return wrapper  # type: ignore[return-value]
+
+    def asNondeterministic(self) -> "UserDefinedFunction":
+        """
+        Updates UserDefinedFunction to nondeterministic.
+        """
+        self.deterministic = False
+        return self
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2c025db0f36..3426f2bdaf6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -10043,6 +10043,7 @@ def udf(
     ...
 
 
+@try_remote_functions
 def udf(
     f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
     returnType: "DataTypeOrString" = StringType(),
@@ -10053,6 +10054,9 @@ def udf(
 
     .. versionadded:: 1.3.0
 
+    .. versionchanged:: 3.4.0
+        Support Spark Connect.
+
     Parameters
     ----------
     f : function
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index e61aeced30c..b74b1a9ee69 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -19,7 +19,7 @@ import tempfile
 
 from pyspark.errors import PySparkTypeError
 from pyspark.sql import SparkSession
-from pyspark.sql.types import StructType, StructField, ArrayType, IntegerType
+from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import ReusedPySparkTestCase
@@ -2282,15 +2282,52 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
             ).toPandas(),
         )
 
+    def test_udf(self):
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT a, b, c FROM VALUES
+            (1, 1.0, 'x'), (2, 2.0, 'y'), (3, 3.0, 'z')
+            AS tab(a, b, c)
+            """
+        # +---+---+---+
+        # |  a|  b|  c|
+        # +---+---+---+
+        # |  1|1.0|  x|
+        # |  2|2.0|  y|
+        # |  3|3.0|  z|
+        # +---+---+---+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # as a normal function
+        self.assert_eq(
+            cdf.withColumn("A", CF.udf(lambda x: x + 1)(cdf.a)).toPandas(),
+            sdf.withColumn("A", SF.udf(lambda x: x + 1)(sdf.a)).toPandas(),
+        )
+
+        # as a decorator
+        @CF.udf(StringType())
+        def cfun(x):
+            return x + "a"
+
+        @SF.udf(StringType())
+        def sfun(x):
+            return x + "a"
+
+        self.assert_eq(
+            cdf.withColumn("A", cfun(cdf.c)).toPandas(),
+            sdf.withColumn("A", sfun(sdf.c)).toPandas(),
+        )
+
     def test_unsupported_functions(self):
         # SPARK-41928: Disable unsupported functions.
 
         from pyspark.sql.connect import functions as CF
 
-        for f in (
-            "udf",
-            "pandas_udf",
-        ):
+        for f in ("pandas_udf",):
             with self.assertRaises(NotImplementedError):
                 getattr(CF, f)()
 

