Repository: spark Updated Branches: refs/heads/branch-2.4 5d98c3194 -> ffd036a6d
[SPARK-23672][PYTHON] Document support for nested return types in scalar with arrow udfs ## What changes were proposed in this pull request? Clarify docstring for Scalar functions ## How was this patch tested? Adds a unit test showing use similar to wordcount, there's existing unit test for array of floats as well. Closes #20908 from holdenk/SPARK-23672-document-support-for-nested-return-types-in-scalar-with-arrow-udfs. Authored-by: Holden Karau <hol...@pigscanfly.ca> Signed-off-by: Bryan Cutler <cutl...@gmail.com> (cherry picked from commit da5685b5bb9ee7daaeb4e8f99c488ebd50c7aac3) Signed-off-by: Bryan Cutler <cutl...@gmail.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ffd036a6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ffd036a6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ffd036a6 Branch: refs/heads/branch-2.4 Commit: ffd036a6d13814ebcc332990be1e286939cc6abe Parents: 5d98c31 Author: Holden Karau <hol...@pigscanfly.ca> Authored: Mon Sep 10 11:01:51 2018 -0700 Committer: Bryan Cutler <cutl...@gmail.com> Committed: Mon Sep 10 11:02:09 2018 -0700 ---------------------------------------------------------------------- python/pyspark/sql/functions.py | 3 ++- python/pyspark/sql/tests.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ffd036a6/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9396b16..81f35f5 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2720,9 +2720,10 @@ def pandas_udf(f=None, returnType=None, functionType=None): 1. SCALAR A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The returnType should be a primitive data type, e.g., :class:`DoubleType`. The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + :class:`MapType`, :class:`StructType` are currently not supported as output types. + Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and :meth:`pyspark.sql.DataFrame.select`. http://git-wip-us.apache.org/repos/asf/spark/blob/ffd036a6/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6d9d636..8e5bc67 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4443,6 +4443,7 @@ class ArrowTests(ReusedSQLTestCase): not _have_pandas or not _have_pyarrow, _pandas_requirement_message or _pyarrow_requirement_message) class PandasUDFTests(ReusedSQLTestCase): + def test_pandas_udf_basic(self): from pyspark.rdd import PythonEvalType from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4658,6 +4659,24 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): random_udf = random_udf.asNondeterministic() return random_udf + def test_pandas_udf_tokenize(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')), + ArrayType(StringType())) + self.assertEqual(tokenize.returnType, ArrayType(StringType())) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect()) + + def test_pandas_udf_nested_arrays(self): + from pyspark.sql.functions import pandas_udf + tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]), + ArrayType(ArrayType(StringType()))) + self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType()))) + df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) + result = df.select(tokenize("vals").alias("hi")) + self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect()) + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col, array df = self.spark.range(10).select( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org