Repository: spark Updated Branches: refs/heads/master cafca54c0 -> 63d90e7da
[SPARK-18777][PYTHON][SQL] Return UDF from udf.register ## What changes were proposed in this pull request? - Move udf wrapping code from `functions.udf` to `functions.UserDefinedFunction`. - Return wrapped udf from `catalog.registerFunction` and dependent methods. - Update docstrings in `catalog.registerFunction` and `SQLContext.registerFunction`. - Unit tests. ## How was this patch tested? - Existing unit tests and doctests. - Additional tests covering new feature. Author: zero323 <[email protected]> Closes #17831 from zero323/SPARK-18777. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/63d90e7d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/63d90e7d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/63d90e7d Branch: refs/heads/master Commit: 63d90e7da4913917982c0501d63ccc433a9b6b46 Parents: cafca54 Author: zero323 <[email protected]> Authored: Sat May 6 22:28:42 2017 -0700 Committer: Xiao Li <[email protected]> Committed: Sat May 6 22:28:42 2017 -0700 ---------------------------------------------------------------------- python/pyspark/sql/catalog.py | 11 ++++++++--- python/pyspark/sql/context.py | 12 ++++++++---- python/pyspark/sql/functions.py | 23 ++++++++++++++--------- python/pyspark/sql/tests.py | 9 +++++++++ 4 files changed, 39 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/63d90e7d/python/pyspark/sql/catalog.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 41e68a4..5f25dce 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -237,23 +237,28 @@ class Catalog(object): :param name: name of the UDF :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + :return: a wrapped :class:`UserDefinedFunction` - 
>>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x)) + >>> strlen = spark.catalog.registerFunction("stringLengthString", len) >>> spark.sql("SELECT stringLengthString('test')").collect() [Row(stringLengthString(test)=u'4')] + >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + >>> from pyspark.sql.types import IntegerType - >>> spark.catalog.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType - >>> spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ udf = UserDefinedFunction(f, returnType, name) self._jsparkSession.udf().registerPython(name, udf._judf) + return udf._wrapped() @since(2.0) def isCached(self, tableName): http://git-wip-us.apache.org/repos/asf/spark/blob/63d90e7d/python/pyspark/sql/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index fdb7abb..5197a9e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -185,22 +185,26 @@ class SQLContext(object): :param name: name of the UDF :param f: python function :param returnType: a :class:`pyspark.sql.types.DataType` object + :return: a wrapped :class:`UserDefinedFunction` - >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) + >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() [Row(stringLengthString(test)=u'4')] + >>> sqlContext.sql("SELECT 'foo' AS 
text").select(strlen("text")).collect() + [Row(stringLengthString(text)=u'3')] + >>> from pyspark.sql.types import IntegerType - >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] >>> from pyspark.sql.types import IntegerType - >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) + >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] """ - self.sparkSession.catalog.registerFunction(name, f, returnType) + return self.sparkSession.catalog.registerFunction(name, f, returnType) @ignore_unicode_prefix @since(2.1) http://git-wip-us.apache.org/repos/asf/spark/blob/63d90e7d/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 843ae38..8b3487c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1917,6 +1917,19 @@ class UserDefinedFunction(object): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + @functools.wraps(self.func) + def wrapper(*args): + return self(*args) + + wrapper.func = self.func + wrapper.returnType = self.returnType + + return wrapper + @since(1.3) def udf(f=None, returnType=StringType()): @@ -1951,15 +1964,7 @@ def udf(f=None, returnType=StringType()): """ def _udf(f, returnType=StringType()): udf_obj = UserDefinedFunction(f, returnType) - - @functools.wraps(f) - def wrapper(*args): - return udf_obj(*args) - - wrapper.func = udf_obj.func - wrapper.returnType = udf_obj.returnType 
- - return wrapper + return udf_obj._wrapped() # decorator @udf, @udf() or @udf(dataType()) if f is None or isinstance(f, (str, DataType)): http://git-wip-us.apache.org/repos/asf/spark/blob/63d90e7d/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f644624..7983bc5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -436,6 +436,15 @@ class SQLTests(ReusedPySparkTestCase): res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_udf_registration_returns_udf(self): + df = self.spark.range(10) + add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType()) + + self.assertListEqual( + df.selectExpr("add_three(id) AS plus_three").collect(), + df.select(add_three("id").alias("plus_three")).collect() + ) + def test_wholefile_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
