Repository: spark
Updated Branches:
  refs/heads/master cafca54c0 -> 63d90e7da


[SPARK-18777][PYTHON][SQL] Return UDF from udf.register

## What changes were proposed in this pull request?

- Move udf wrapping code from `functions.udf` to `functions.UserDefinedFunction`.
- Return the wrapped udf from `catalog.registerFunction` and dependent methods (see the usage sketch after this list).
- Update docstrings in `catalog.registerFunction` and `SQLContext.registerFunction`.
- Unit tests.
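
A minimal usage sketch of the returned wrapper, mirroring the added unit test (the `add_three` name is only for illustration):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()

# Registering the UDF for SQL use now also returns a wrapped UDF
# that can be applied directly to DataFrame columns.
add_three = spark.udf.register("add_three", lambda x: x + 3, IntegerType())

df = spark.range(10)
df.selectExpr("add_three(id) AS plus_three").show()    # SQL-style call
df.select(add_three("id").alias("plus_three")).show()  # DataFrame-style call
```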

## How was this patch tested?

- Existing unit tests and doctests.
- Additional tests covering new feature.

Author: zero323 <[email protected]>

Closes #17831 from zero323/SPARK-18777.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/63d90e7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/63d90e7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/63d90e7d

Branch: refs/heads/master
Commit: 63d90e7da4913917982c0501d63ccc433a9b6b46
Parents: cafca54
Author: zero323 <[email protected]>
Authored: Sat May 6 22:28:42 2017 -0700
Committer: Xiao Li <[email protected]>
Committed: Sat May 6 22:28:42 2017 -0700

----------------------------------------------------------------------
 python/pyspark/sql/catalog.py   | 11 ++++++++---
 python/pyspark/sql/context.py   | 12 ++++++++----
 python/pyspark/sql/functions.py | 23 ++++++++++++++---------
 python/pyspark/sql/tests.py     |  9 +++++++++
 4 files changed, 39 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/63d90e7d/python/pyspark/sql/catalog.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 41e68a4..5f25dce 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -237,23 +237,28 @@ class Catalog(object):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
+        :return: a wrapped :class:`UserDefinedFunction`
 
-        >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x))
+        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
         >>> spark.sql("SELECT stringLengthString('test')").collect()
         [Row(stringLengthString(test)=u'4')]
 
+        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+        [Row(stringLengthString(text)=u'3')]
+
         >>> from pyspark.sql.types import IntegerType
-        >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
         >>> from pyspark.sql.types import IntegerType
-        >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
         >>> spark.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
         udf = UserDefinedFunction(f, returnType, name)
         self._jsparkSession.udf().registerPython(name, udf._judf)
+        return udf._wrapped()
 
     @since(2.0)
     def isCached(self, tableName):

http://git-wip-us.apache.org/repos/asf/spark/blob/63d90e7d/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index fdb7abb..5197a9e 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -185,22 +185,26 @@ class SQLContext(object):
         :param name: name of the UDF
         :param f: python function
         :param returnType: a :class:`pyspark.sql.types.DataType` object
+        :return: a wrapped :class:`UserDefinedFunction`
 
-        >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
+        >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
         [Row(stringLengthString(test)=u'4')]
 
+        >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+        [Row(stringLengthString(text)=u'3')]
+
         >>> from pyspark.sql.types import IntegerType
-        >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
 
         >>> from pyspark.sql.types import IntegerType
-        >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(stringLengthInt(test)=4)]
         """
-        self.sparkSession.catalog.registerFunction(name, f, returnType)
+        return self.sparkSession.catalog.registerFunction(name, f, returnType)
 
     @ignore_unicode_prefix
     @since(2.1)

http://git-wip-us.apache.org/repos/asf/spark/blob/63d90e7d/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 843ae38..8b3487c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1917,6 +1917,19 @@ class UserDefinedFunction(object):
         sc = SparkContext._active_spark_context
         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
 
+    def _wrapped(self):
+        """
+        Wrap this udf with a function and attach docstring from func
+        """
+        @functools.wraps(self.func)
+        def wrapper(*args):
+            return self(*args)
+
+        wrapper.func = self.func
+        wrapper.returnType = self.returnType
+
+        return wrapper
+
 
 @since(1.3)
 def udf(f=None, returnType=StringType()):
@@ -1951,15 +1964,7 @@ def udf(f=None, returnType=StringType()):
     """
     def _udf(f, returnType=StringType()):
         udf_obj = UserDefinedFunction(f, returnType)
-
-        @functools.wraps(f)
-        def wrapper(*args):
-            return udf_obj(*args)
-
-        wrapper.func = udf_obj.func
-        wrapper.returnType = udf_obj.returnType
-
-        return wrapper
+        return udf_obj._wrapped()
 
     # decorator @udf, @udf() or @udf(dataType())
     if f is None or isinstance(f, (str, DataType)):

http://git-wip-us.apache.org/repos/asf/spark/blob/63d90e7d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index f644624..7983bc5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -436,6 +436,15 @@ class SQLTests(ReusedPySparkTestCase):
         res.explain(True)
         self.assertEqual(res.collect(), [Row(id=0, copy=0)])
 
+    def test_udf_registration_returns_udf(self):
+        df = self.spark.range(10)
+        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_three(id) AS plus_three").collect(),
+            df.select(add_three("id").alias("plus_three")).collect()
+        )
+
     def test_wholefile_json(self):
         people1 = self.spark.read.json("python/test_support/sql/people.json")
         people_array = self.spark.read.json("python/test_support/sql/people_array.json",

