Repository: spark
Updated Branches:
  refs/heads/master 568763baf -> b8624b06e


[SPARK-20396][SQL][PYSPARK][FOLLOW-UP] groupby().apply() with pandas udf

## What changes were proposed in this pull request?

This is a follow-up of #18732.
This PR modifies the `GroupedData.apply()` method to implicitly convert a pandas 
udf to a grouped udf.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <[email protected]>

Closes #19517 from ueshin/issues/SPARK-20396/fup2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b8624b06
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b8624b06
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b8624b06

Branch: refs/heads/master
Commit: b8624b06e5d531ebc14acb05da286f96f4bc9515
Parents: 568763b
Author: Takuya UESHIN <[email protected]>
Authored: Fri Oct 20 12:44:30 2017 -0700
Committer: gatorsmile <[email protected]>
Committed: Fri Oct 20 12:44:30 2017 -0700

----------------------------------------------------------------------
 .../apache/spark/api/python/PythonRunner.scala  |  1 +
 python/pyspark/serializers.py                   |  1 +
 python/pyspark/sql/functions.py                 | 33 +++++++++++------
 python/pyspark/sql/group.py                     | 14 ++++---
 python/pyspark/sql/tests.py                     | 37 +++++++++++++++++++
 python/pyspark/worker.py                        | 39 +++++++++-----------
 .../plans/logical/pythonLogicalOperators.scala  |  9 +++--
 .../spark/sql/RelationalGroupedDataset.scala    |  7 ++--
 .../execution/python/ExtractPythonUDFs.scala    |  6 ++-
 .../python/FlatMapGroupsInPandasExec.scala      |  2 +-
 .../spark/sql/execution/python/PythonUDF.scala  |  2 +-
 .../python/UserDefinedPythonFunction.scala      | 13 ++++++-
 .../python/BatchEvalPythonExecSuite.scala       |  2 +-
 13 files changed, 114 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 3688a14..d417303 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -36,6 +36,7 @@ private[spark] object PythonEvalType {
   val NON_UDF = 0
   val SQL_BATCHED_UDF = 1
   val SQL_PANDAS_UDF = 2
+  val SQL_PANDAS_GROUPED_UDF = 3
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index ad18bd0..a0adeed 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -86,6 +86,7 @@ class PythonEvalType(object):
     NON_UDF = 0
     SQL_BATCHED_UDF = 1
     SQL_PANDAS_UDF = 2
+    SQL_PANDAS_GROUPED_UDF = 3
 
 
 class Serializer(object):

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9bc12c3..9bc374b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2038,13 +2038,22 @@ def _wrap_function(sc, func, returnType):
                                   sc.pythonVer, broadcast_vars, 
sc._javaAccumulator)
 
 
+class PythonUdfType(object):
+    # row-at-a-time UDFs
+    NORMAL_UDF = 0
+    # scalar vectorized UDFs
+    PANDAS_UDF = 1
+    # grouped vectorized UDFs
+    PANDAS_GROUPED_UDF = 2
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType, name=None, vectorized=False):
+    def __init__(self, func, returnType, name=None, 
pythonUdfType=PythonUdfType.NORMAL_UDF):
         if not callable(func):
             raise TypeError(
                 "Not a function or callable (__call__ is not defined): "
@@ -2058,7 +2067,7 @@ class UserDefinedFunction(object):
         self._name = name or (
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
-        self.vectorized = vectorized
+        self.pythonUdfType = pythonUdfType
 
     @property
     def returnType(self):
@@ -2090,7 +2099,7 @@ class UserDefinedFunction(object):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = 
sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self.vectorized)
+            self._name, wrapped_func, jdt, self.pythonUdfType)
         return judf
 
     def __call__(self, *cols):
@@ -2121,15 +2130,15 @@ class UserDefinedFunction(object):
 
         wrapper.func = self.func
         wrapper.returnType = self.returnType
-        wrapper.vectorized = self.vectorized
+        wrapper.pythonUdfType = self.pythonUdfType
 
         return wrapper
 
 
-def _create_udf(f, returnType, vectorized):
+def _create_udf(f, returnType, pythonUdfType):
 
-    def _udf(f, returnType=StringType(), vectorized=vectorized):
-        if vectorized:
+    def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType):
+        if pythonUdfType == PythonUdfType.PANDAS_UDF:
             import inspect
             argspec = inspect.getargspec(f)
             if len(argspec.args) == 0 and argspec.varargs is None:
@@ -2137,7 +2146,7 @@ def _create_udf(f, returnType, vectorized):
                     "0-arg pandas_udfs are not supported. "
                     "Instead, create a 1-arg pandas_udf and ignore the arg in 
your function."
                 )
-        udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
+        udf_obj = UserDefinedFunction(f, returnType, 
pythonUdfType=pythonUdfType)
         return udf_obj._wrapped()
 
     # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
@@ -2145,9 +2154,9 @@ def _create_udf(f, returnType, vectorized):
         # If DataType has been passed as a positional argument
         # for decorator use it as a returnType
         return_type = f or returnType
-        return functools.partial(_udf, returnType=return_type, 
vectorized=vectorized)
+        return functools.partial(_udf, returnType=return_type, 
pythonUdfType=pythonUdfType)
     else:
-        return _udf(f=f, returnType=returnType, vectorized=vectorized)
+        return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType)
 
 
 @since(1.3)
@@ -2181,7 +2190,7 @@ def udf(f=None, returnType=StringType()):
     |         8|      JOHN DOE|          22|
     +----------+--------------+------------+
     """
-    return _create_udf(f, returnType=returnType, vectorized=False)
+    return _create_udf(f, returnType=returnType, 
pythonUdfType=PythonUdfType.NORMAL_UDF)
 
 
 @since(2.3)
@@ -2252,7 +2261,7 @@ def pandas_udf(f=None, returnType=StringType()):
 
     .. note:: The user-defined function must be deterministic.
     """
-    return _create_udf(f, returnType=returnType, vectorized=True)
+    return _create_udf(f, returnType=returnType, 
pythonUdfType=PythonUdfType.PANDAS_UDF)
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/sql/group.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 817d0bc..e11388d 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -19,6 +19,7 @@ from pyspark import since
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.sql.column import Column, _to_seq, _to_java_column, 
_create_column_from_literal
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.functions import PythonUdfType, UserDefinedFunction
 from pyspark.sql.types import *
 
 __all__ = ["GroupedData"]
@@ -235,11 +236,13 @@ class GroupedData(object):
         .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
 
         """
-        from pyspark.sql.functions import pandas_udf
+        import inspect
 
         # Columns are special because hasattr always return True
-        if isinstance(udf, Column) or not hasattr(udf, 'func') or not 
udf.vectorized:
-            raise ValueError("The argument to apply must be a pandas_udf")
+        if isinstance(udf, Column) or not hasattr(udf, 'func') \
+           or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \
+           or len(inspect.getargspec(udf.func).args) != 1:
+            raise ValueError("The argument to apply must be a 1-arg 
pandas_udf")
         if not isinstance(udf.returnType, StructType):
             raise ValueError("The returnType of the pandas_udf must be a 
StructType")
 
@@ -268,8 +271,9 @@ class GroupedData(object):
             return [(result[result.columns[i]], arrow_type)
                     for i, arrow_type in enumerate(arrow_return_types)]
 
-        wrapped_udf_obj = pandas_udf(wrapped, returnType)
-        udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
+        udf_obj = UserDefinedFunction(
+            wrapped, returnType, name=udf.__name__, 
pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF)
+        udf_column = udf_obj(*[df[col] for col in df.columns])
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.sql_ctx)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index bac2ef8..685eebc 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3383,6 +3383,15 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
         res = df.select(f(col('id')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_unsupported_types(self):
+        from pyspark.sql.functions import pandas_udf, col
+        schema = StructType([StructField("dt", DateType(), True)])
+        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], 
schema=schema)
+        f = pandas_udf(lambda x: x, DateType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+                df.select(f(col('dt'))).collect()
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not 
installed")
 class GroupbyApplyTests(ReusedPySparkTestCase):
@@ -3492,6 +3501,18 @@ class GroupbyApplyTests(ReusedPySparkTestCase):
         expected = expected.assign(norm=expected.norm.astype('float64'))
         self.assertFramesEqual(expected, result)
 
+    def test_datatype_string(self):
+        from pyspark.sql.functions import pandas_udf
+        df = self.data
+
+        foo_udf = pandas_udf(
+            lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            "id long, v int, v1 double, v2 long")
+
+        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
+        expected = 
df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
+        self.assertFramesEqual(expected, result)
+
     def test_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf
         df = self.data
@@ -3517,9 +3538,25 @@ class GroupbyApplyTests(ReusedPySparkTestCase):
                 df.groupby('id').apply(sum(df.v))
             with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
                 df.groupby('id').apply(df.v + 1)
+            with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
+                df.groupby('id').apply(
+                    pandas_udf(lambda: 1, StructType([StructField("d", 
DoubleType())])))
+            with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
+                df.groupby('id').apply(
+                    pandas_udf(lambda x, y: x, StructType([StructField("d", 
DoubleType())])))
             with self.assertRaisesRegexp(ValueError, 'returnType'):
                 df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType()))
 
+    def test_unsupported_types(self):
+        from pyspark.sql.functions import pandas_udf, col
+        schema = StructType(
+            [StructField("id", LongType(), True), StructField("dt", 
DateType(), True)])
+        df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], 
schema=schema)
+        f = pandas_udf(lambda x: x, df.schema)
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+                df.groupby('id').apply(f).collect()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests import *

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index eb6d486..5e100e0 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -32,7 +32,7 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, 
PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
-from pyspark.sql.types import to_arrow_type, StructType
+from pyspark.sql.types import to_arrow_type
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -74,28 +74,19 @@ def wrap_udf(f, return_type):
 
 
 def wrap_pandas_udf(f, return_type):
-    # If the return_type is a StructType, it indicates this is a groupby apply 
udf,
-    # and has already been wrapped under apply(), otherwise, it's a vectorized 
column udf.
-    # We can distinguish these two by return type because in groupby apply, we 
always specify
-    # returnType as a StructType, and in vectorized column udf, StructType is 
not supported.
-    #
-    # TODO: Look into refactoring use of StructType to be more flexible for 
future pandas_udfs
-    if isinstance(return_type, StructType):
-        return lambda *a: f(*a)
-    else:
-        arrow_return_type = to_arrow_type(return_type)
+    arrow_return_type = to_arrow_type(return_type)
 
-        def verify_result_length(*a):
-            result = f(*a)
-            if not hasattr(result, "__len__"):
-                raise TypeError("Return type of the user-defined functon 
should be "
-                                "Pandas.Series, but is 
{}".format(type(result)))
-            if len(result) != len(a[0]):
-                raise RuntimeError("Result vector from pandas_udf was not the 
required length: "
-                                   "expected %d, got %d" % (len(a[0]), 
len(result)))
-            return result
+    def verify_result_length(*a):
+        result = f(*a)
+        if not hasattr(result, "__len__"):
+            raise TypeError("Return type of the user-defined functon should be 
"
+                            "Pandas.Series, but is {}".format(type(result)))
+        if len(result) != len(a[0]):
+            raise RuntimeError("Result vector from pandas_udf was not the 
required length: "
+                               "expected %d, got %d" % (len(a[0]), 
len(result)))
+        return result
 
-        return lambda *a: (verify_result_length(*a), arrow_return_type)
+    return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
 def read_single_udf(pickleSer, infile, eval_type):
@@ -111,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type):
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_PANDAS_UDF:
         return arg_offsets, wrap_pandas_udf(row_func, return_type)
+    elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF:
+        # a groupby apply udf has already been wrapped under apply()
+        return arg_offsets, row_func
     else:
         return arg_offsets, wrap_udf(row_func, return_type)
 
@@ -133,7 +127,8 @@ def read_udfs(pickleSer, infile, eval_type):
 
     func = lambda _, it: map(mapper, it)
 
-    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
+    if eval_type == PythonEvalType.SQL_PANDAS_UDF \
+       or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF:
         ser = ArrowStreamPandasSerializer()
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 8abab24..254687e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -24,10 +24,11 @@ import 
org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre
  * This is used by DataFrame.groupby().apply().
  */
 case class FlatMapGroupsInPandas(
-  groupingAttributes: Seq[Attribute],
-  functionExpr: Expression,
-  output: Seq[Attribute],
-  child: LogicalPlan) extends UnaryNode {
+    groupingAttributes: Seq[Attribute],
+    functionExpr: Expression,
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
   /**
    * This is needed because output attributes are considered `references` when
    * passed through the constructor.

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 33ec3a2..6b45790 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
-import org.apache.spark.sql.execution.python.PythonUDF
+import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{NumericType, StructType}
 
@@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql](
   }
 
   /**
-   * Applies a vectorized python user-defined function to each group of data.
+   * Applies a grouped vectorized python user-defined function to each group 
of data.
    * The user-defined function defines a transformation: `pandas.DataFrame` -> 
`pandas.DataFrame`.
    * For each group, all elements in the group are passed as a 
`pandas.DataFrame` and the results
    * for all groups are combined into a new [[DataFrame]].
@@ -449,7 +449,8 @@ class RelationalGroupedDataset protected[sql](
    * workers.
    */
   private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = {
-    require(expr.vectorized, "Must pass a vectorized python udf")
+    require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF,
+      "Must pass a grouped vectorized python udf")
     require(expr.dataType.isInstanceOf[StructType],
       "The returnType of the vectorized python udf must be a StructType")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index e3f952e..d682536 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -137,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with 
PredicateHelper {
           udf.references.subsetOf(child.outputSet)
         }
         if (validUdfs.nonEmpty) {
+          if (validUdfs.exists(_.pythonUdfType == 
PythonUdfType.PANDAS_GROUPED_UDF)) {
+            throw new IllegalArgumentException("Can not use grouped vectorized 
UDFs")
+          }
+
           val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
             AttributeReference(s"pythonUDF$i", u.dataType)()
           }
 
-          val evaluation = validUdfs.partition(_.vectorized) match {
+          val evaluation = validUdfs.partition(_.pythonUdfType == 
PythonUdfType.PANDAS_UDF) match {
             case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
               ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, 
child)
             case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index b996b5b..5ed88ad 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec(
 
       val columnarBatchIter = new ArrowPythonRunner(
         chainedFunc, bufferSize, reuseWorker,
-        PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
+        PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema)
         .compute(grouped, context.partitionId(), context)
 
       
columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output,
 output))

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 84a6d9e..9c07c76 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -29,7 +29,7 @@ case class PythonUDF(
     func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression],
-    vectorized: Boolean)
+    pythonUdfType: Int)
   extends Expression with Unevaluable with NonSQLExpression with 
UserDefinedExpression {
 
   override def toString: String = s"$name(${children.mkString(", ")})"

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index a30a80a..b2fe6c3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -22,6 +22,15 @@ import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types.DataType
 
+private[spark] object PythonUdfType {
+  // row-at-a-time UDFs
+  val NORMAL_UDF = 0
+  // scalar vectorized UDFs
+  val PANDAS_UDF = 1
+  // grouped vectorized UDFs
+  val PANDAS_GROUPED_UDF = 2
+}
+
 /**
  * A user-defined Python function. This is used by the Python API.
  */
@@ -29,10 +38,10 @@ case class UserDefinedPythonFunction(
     name: String,
     func: PythonFunction,
     dataType: DataType,
-    vectorized: Boolean) {
+    pythonUdfType: Int) {
 
   def builder(e: Seq[Expression]): PythonUDF = {
-    PythonUDF(name, func, dataType, e, vectorized)
+    PythonUDF(name, func, dataType, e, pythonUdfType)
   }
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the 
given input. */

http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 153e6e1..95b21fc 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
   name = "dummyUDF",
   func = new DummyUDF,
   dataType = BooleanType,
-  vectorized = false)
+  pythonUdfType = PythonUdfType.NORMAL_UDF)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to