Repository: spark Updated Branches: refs/heads/master 568763baf -> b8624b06e
[SPARK-20396][SQL][PYSPARK][FOLLOW-UP] groupby().apply() with pandas udf ## What changes were proposed in this pull request? This is a follow-up of #18732. This PR modifies `GroupedData.apply()` method to convert pandas udf to grouped udf implicitly. ## How was this patch tested? Existing tests. Author: Takuya UESHIN <[email protected]> Closes #19517 from ueshin/issues/SPARK-20396/fup2. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b8624b06 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b8624b06 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b8624b06 Branch: refs/heads/master Commit: b8624b06e5d531ebc14acb05da286f96f4bc9515 Parents: 568763b Author: Takuya UESHIN <[email protected]> Authored: Fri Oct 20 12:44:30 2017 -0700 Committer: gatorsmile <[email protected]> Committed: Fri Oct 20 12:44:30 2017 -0700 ---------------------------------------------------------------------- .../apache/spark/api/python/PythonRunner.scala | 1 + python/pyspark/serializers.py | 1 + python/pyspark/sql/functions.py | 33 +++++++++++------ python/pyspark/sql/group.py | 14 ++++--- python/pyspark/sql/tests.py | 37 +++++++++++++++++++ python/pyspark/worker.py | 39 +++++++++----------- .../plans/logical/pythonLogicalOperators.scala | 9 +++-- .../spark/sql/RelationalGroupedDataset.scala | 7 ++-- .../execution/python/ExtractPythonUDFs.scala | 6 ++- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../spark/sql/execution/python/PythonUDF.scala | 2 +- .../python/UserDefinedPythonFunction.scala | 13 ++++++- .../python/BatchEvalPythonExecSuite.scala | 2 +- 13 files changed, 114 insertions(+), 52 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala ---------------------------------------------------------------------- diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 3688a14..d417303 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -36,6 +36,7 @@ private[spark] object PythonEvalType { val NON_UDF = 0 val SQL_BATCHED_UDF = 1 val SQL_PANDAS_UDF = 2 + val SQL_PANDAS_GROUPED_UDF = 3 } /** http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/serializers.py ---------------------------------------------------------------------- diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ad18bd0..a0adeed 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -86,6 +86,7 @@ class PythonEvalType(object): NON_UDF = 0 SQL_BATCHED_UDF = 1 SQL_PANDAS_UDF = 2 + SQL_PANDAS_GROUPED_UDF = 3 class Serializer(object): http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9bc12c3..9bc374b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2038,13 +2038,22 @@ def _wrap_function(sc, func, returnType): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class PythonUdfType(object): + # row-at-a-time UDFs + NORMAL_UDF = 0 + # scalar vectorized UDFs + PANDAS_UDF = 1 + # grouped vectorized UDFs + PANDAS_GROUPED_UDF = 2 + + class UserDefinedFunction(object): """ User defined function in Python .. 
versionadded:: 1.3 """ - def __init__(self, func, returnType, name=None, vectorized=False): + def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): if not callable(func): raise TypeError( "Not a function or callable (__call__ is not defined): " @@ -2058,7 +2067,7 @@ class UserDefinedFunction(object): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self.vectorized = vectorized + self.pythonUdfType = pythonUdfType @property def returnType(self): @@ -2090,7 +2099,7 @@ class UserDefinedFunction(object): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.vectorized) + self._name, wrapped_func, jdt, self.pythonUdfType) return judf def __call__(self, *cols): @@ -2121,15 +2130,15 @@ class UserDefinedFunction(object): wrapper.func = self.func wrapper.returnType = self.returnType - wrapper.vectorized = self.vectorized + wrapper.pythonUdfType = self.pythonUdfType return wrapper -def _create_udf(f, returnType, vectorized): +def _create_udf(f, returnType, pythonUdfType): - def _udf(f, returnType=StringType(), vectorized=vectorized): - if vectorized: + def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): + if pythonUdfType == PythonUdfType.PANDAS_UDF: import inspect argspec = inspect.getargspec(f) if len(argspec.args) == 0 and argspec.varargs is None: @@ -2137,7 +2146,7 @@ def _create_udf(f, returnType, vectorized): "0-arg pandas_udfs are not supported. " "Instead, create a 1-arg pandas_udf and ignore the arg in your function." 
) - udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) + udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) return udf_obj._wrapped() # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf @@ -2145,9 +2154,9 @@ def _create_udf(f, returnType, vectorized): # If DataType has been passed as a positional argument # for decorator use it as a returnType return_type = f or returnType - return functools.partial(_udf, returnType=return_type, vectorized=vectorized) + return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) else: - return _udf(f=f, returnType=returnType, vectorized=vectorized) + return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) @since(1.3) @@ -2181,7 +2190,7 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, vectorized=False) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) @since(2.3) @@ -2252,7 +2261,7 @@ def pandas_udf(f=None, returnType=StringType()): .. note:: The user-defined function must be deterministic. 
""" - return _create_udf(f, returnType=returnType, vectorized=True) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) blacklist = ['map', 'since', 'ignore_unicode_prefix'] http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/sql/group.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 817d0bc..e11388d 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,6 +19,7 @@ from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import PythonUdfType, UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -235,11 +236,13 @@ class GroupedData(object): .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ - from pyspark.sql.functions import pandas_udf + import inspect # Columns are special because hasattr always return True - if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized: - raise ValueError("The argument to apply must be a pandas_udf") + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ + or len(inspect.getargspec(udf.func).args) != 1: + raise ValueError("The argument to apply must be a 1-arg pandas_udf") if not isinstance(udf.returnType, StructType): raise ValueError("The returnType of the pandas_udf must be a StructType") @@ -268,8 +271,9 @@ class GroupedData(object): return [(result[result.columns[i]], arrow_type) for i, arrow_type in enumerate(arrow_return_types)] - wrapped_udf_obj = pandas_udf(wrapped, returnType) - udf_column = wrapped_udf_obj(*[df[col] for col in df.columns]) + udf_obj = UserDefinedFunction( + wrapped, returnType, name=udf.__name__, 
pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) + udf_column = udf_obj(*[df[col] for col in df.columns]) jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bac2ef8..685eebc 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3383,6 +3383,15 @@ class VectorizedUDFTests(ReusedPySparkTestCase): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, DateType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.select(f(col('dt'))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): @@ -3492,6 +3501,18 @@ class GroupbyApplyTests(ReusedPySparkTestCase): expected = expected.assign(norm=expected.norm.astype('float64')) self.assertFramesEqual(expected, result) + def test_datatype_string(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + "id long, v int, v1 double, v2 long") + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf df = self.data @@ -3517,9 +3538,25 @@ class 
GroupbyApplyTests(ReusedPySparkTestCase): df.groupby('id').apply(sum(df.v)) with self.assertRaisesRegexp(ValueError, 'pandas_udf'): df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) with self.assertRaisesRegexp(ValueError, 'returnType'): df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + def test_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType( + [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, df.schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.groupby('id').apply(f).collect() + if __name__ == "__main__": from pyspark.sql.tests import * http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/python/pyspark/worker.py ---------------------------------------------------------------------- diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index eb6d486..5e100e0 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,7 +32,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type, StructType +from pyspark.sql.types import to_arrow_type from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,28 +74,19 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - # If the return_type is a StructType, it indicates this is a 
groupby apply udf, - # and has already been wrapped under apply(), otherwise, it's a vectorized column udf. - # We can distinguish these two by return type because in groupby apply, we always specify - # returnType as a StructType, and in vectorized column udf, StructType is not supported. - # - # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs - if isinstance(return_type, StructType): - return lambda *a: f(*a) - else: - arrow_return_type = to_arrow_type(return_type) + arrow_return_type = to_arrow_type(return_type) - def verify_result_length(*a): - result = f(*a) - if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " - "Pandas.Series, but is {}".format(type(result))) - if len(result) != len(a[0]): - raise RuntimeError("Result vector from pandas_udf was not the required length: " - "expected %d, got %d" % (len(a[0]), len(result))) - return result + def verify_result_length(*a): + result = f(*a) + if not hasattr(result, "__len__"): + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) + if len(result) != len(a[0]): + raise RuntimeError("Result vector from pandas_udf was not the required length: " + "expected %d, got %d" % (len(a[0]), len(result))) + return result - return lambda *a: (verify_result_length(*a), arrow_return_type) + return lambda *a: (verify_result_length(*a), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type): @@ -111,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type): # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_PANDAS_UDF: return arg_offsets, wrap_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + # a groupby apply udf has already been wrapped under apply() + return arg_offsets, row_func else: return arg_offsets, wrap_udf(row_func, return_type) @@ -133,7 +127,8 @@ def 
read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF: + if eval_type == PythonEvalType.SQL_PANDAS_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 8abab24..254687e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre * This is used by DataFrame.groupby().apply(). */ case class FlatMapGroupsInPandas( - groupingAttributes: Seq[Attribute], - functionExpr: Expression, - output: Seq[Attribute], - child: LogicalPlan) extends UnaryNode { + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + /** * This is needed because output attributes are considered `references` when * passed through the constructor. 
http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 33ec3a2..6b45790 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.python.PythonUDF +import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{NumericType, StructType} @@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * Applies a vectorized python user-defined function to each group of data. + * Applies a grouped vectorized python user-defined function to each group of data. * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results * for all groups are combined into a new [[DataFrame]]. @@ -449,7 +449,8 @@ class RelationalGroupedDataset protected[sql]( * workers. 
*/ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { - require(expr.vectorized, "Must pass a vectorized python udf") + require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, + "Must pass a grouped vectorized python udf") require(expr.dataType.isInstanceOf[StructType], "The returnType of the vectorized python udf must be a StructType") http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index e3f952e..d682536 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -137,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { + if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { + throw new IllegalArgumentException("Can not use grouped vectorized UDFs") + } + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.vectorized) match { + val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index b996b5b..5ed88ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 84a6d9e..9c07c76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - vectorized: Boolean) + pythonUdfType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index a30a80a..b2fe6c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,6 +22,15 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType +private[spark] object PythonUdfType { + // row-at-a-time UDFs + val NORMAL_UDF = 0 + // scalar vectorized UDFs + val PANDAS_UDF = 1 + // grouped vectorized UDFs + val PANDAS_GROUPED_UDF = 2 +} + /** * A user-defined Python function. This is used by the Python API. */ @@ -29,10 +38,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - vectorized: Boolean) { + pythonUdfType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, vectorized) + PythonUDF(name, func, dataType, e, pythonUdfType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/b8624b06/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 153e6e1..95b21fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - vectorized = false) + pythonUdfType = PythonUdfType.NORMAL_UDF) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
