Repository: spark
Updated Branches:
  refs/heads/master e58c4cb3c -> a7a93a116


[SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs

## What changes were proposed in this pull request?

This PR adds support for chained Python UDFs, for example:

```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```

Directly chained unary Python UDFs are evaluated together in a single batch of Python UDFs; other combinations may require multiple batches.

For example,
```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#10 AS double(double(1))#9]
:     +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
   +- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
:     +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
   +- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
      +- !BatchPythonEvaluation double(1), [pythonUDF#17]
         +- Scan OneRowRelation[]
```

TODO: support evaluating multiple unrelated Python UDFs in one batch (in a follow-up PR).

## How was this patch tested?

Added new unit tests for chained UDFs.

Author: Davies Liu <[email protected]>

Closes #12014 from davies/py_udfs.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a7a93a11
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a7a93a11
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a7a93a11

Branch: refs/heads/master
Commit: a7a93a116dd9813853ba6f112beb7763931d2006
Parents: e58c4cb
Author: Davies Liu <[email protected]>
Authored: Tue Mar 29 15:06:29 2016 -0700
Committer: Davies Liu <[email protected]>
Committed: Tue Mar 29 15:06:29 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 38 +++++++++++---------
 python/pyspark/sql/functions.py                 | 16 ++++++---
 python/pyspark/sql/tests.py                     |  9 +++++
 python/pyspark/worker.py                        | 33 ++++++++++++++---
 .../python/BatchPythonEvaluation.scala          | 29 ++++++++++-----
 .../execution/python/ExtractPythonUDFs.scala    | 26 +++++++++++++-
 6 files changed, 116 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a7a93a11/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f423b2e..0f579b4 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): 
Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(func, bufferSize, reuse_worker)
+    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -81,14 +81,18 @@ private[spark] case class PythonFunction(
  * A helper class to run Python UDFs in Spark.
  */
 private[spark] class PythonRunner(
-    func: PythonFunction,
+    funcs: Seq[PythonFunction],
     bufferSize: Int,
-    reuse_worker: Boolean)
+    reuse_worker: Boolean,
+    rowBased: Boolean)
   extends Logging {
 
-  private val envVars = func.envVars
-  private val pythonExec = func.pythonExec
-  private val accumulator = func.accumulator
+  // All the Python functions should have the same exec, version and envvars.
+  private val envVars = funcs.head.envVars
+  private val pythonExec = funcs.head.pythonExec
+  private val pythonVer = funcs.head.pythonVer
+
+  private val accumulator = funcs.head.accumulator // TODO: support 
accumulator in multiple UDF
 
   def compute(
       inputIterator: Iterator[_],
@@ -228,10 +232,8 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
-    private val pythonVer = func.pythonVer
-    private val pythonIncludes = func.pythonIncludes
-    private val broadcastVars = func.broadcastVars
-    private val command = func.command
+    private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
+    private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
 
     setDaemon(true)
 
@@ -256,13 +258,13 @@ private[spark] class PythonRunner(
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
         // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size())
-        for (include <- pythonIncludes.asScala) {
+        dataOut.writeInt(pythonIncludes.size)
+        for (include <- pythonIncludes) {
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
         val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.asScala.map(_.id).toSet
+        val newBids = broadcastVars.map(_.id).toSet
         // number of different broadcasts
         val toRemove = oldBids.diff(newBids)
         val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +274,7 @@ private[spark] class PythonRunner(
           dataOut.writeLong(- bid - 1)  // bid >= 0
           oldBids.remove(bid)
         }
-        for (broadcast <- broadcastVars.asScala) {
+        for (broadcast <- broadcastVars) {
           if (!oldBids.contains(broadcast.id)) {
             // send new broadcast
             dataOut.writeLong(broadcast.id)
@@ -282,8 +284,12 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(command.length)
-        dataOut.write(command)
+        dataOut.writeInt(if (rowBased) 1 else 0)
+        dataOut.writeInt(funcs.length)
+        funcs.foreach { f =>
+          dataOut.writeInt(f.command.length)
+          dataOut.write(f.command)
+        }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a93a11/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f5d959e..3211834 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import since, SparkContext
-from pyspark.rdd import _wrap_function, ignore_unicode_prefix
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1648,6 +1648,14 @@ def sort_array(col, asc=True):
 
 # ---------------------------- User Defined Function 
----------------------------------
 
+def _wrap_function(sc, func, returnType):
+    ser = AutoBatchedSerializer(PickleSerializer())
+    command = (func, returnType, ser)
+    pickled_command, broadcast_vars, env, includes = 
_prepare_for_python_RDD(sc, command)
+    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, 
sc.pythonExec,
+                                  sc.pythonVer, broadcast_vars, 
sc._javaAccumulator)
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python
@@ -1662,14 +1670,12 @@ class UserDefinedFunction(object):
 
     def _create_judf(self, name):
         from pyspark.sql import SQLContext
-        f, returnType = self.func, self.returnType  # put them in closure 
`func`
-        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
         sc = SparkContext.getOrCreate()
-        wrapped_func = _wrap_function(sc, func, ser, ser)
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         ctx = SQLContext.getOrCreate(sc)
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
+            f = self.func
             name = f.__name__ if hasattr(f, '__name__') else 
f.__class__.__name__
         judf = 
sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             name, wrapped_func, jdt)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a93a11/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1a5d422..8494756 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -305,6 +305,15 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 
1").collect()
         self.assertEqual(4, res[0])
 
+    def test_chained_python_udf(self):
+        self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
+        [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+        self.assertEqual(row[0], 2)
+        [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+        self.assertEqual(row[0], 4)
+        [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+        self.assertEqual(row[0], 6)
+
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a93a11/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 42c2f8b..0f05fe3 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -50,6 +50,18 @@ def add_path(path):
         sys.path.insert(1, path)
 
 
+def read_command(serializer, file):
+    command = serializer._read_with_length(file)
+    if isinstance(command, Broadcast):
+        command = serializer.loads(command.value)
+    return command
+
+
+def chain(f, g):
+    """Chain two functions together: g(f(*args)); all args are forwarded to f."""
+    return lambda *a: g(f(*a))
+
+
 def main(infile, outfile):
     try:
         boot_time = time.time()
@@ -95,10 +107,23 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        command = pickleSer._read_with_length(infile)
-        if isinstance(command, Broadcast):
-            command = pickleSer.loads(command.value)
-        func, profiler, deserializer, serializer = command
+        row_based = read_int(infile)
+        num_commands = read_int(infile)
+        if row_based:
+            profiler = None  # profiling is not supported for UDF
+            row_func = None
+            for i in range(num_commands):
+                f, returnType, deserializer = read_command(pickleSer, infile)
+                if row_func is None:
+                    row_func = f
+                else:
+                    row_func = chain(row_func, f)
+            serializer = deserializer
+            func = lambda _, it: map(lambda x: 
returnType.toInternal(row_func(*x)), it)
+        else:
+            assert num_commands == 1
+            func, profiler, deserializer, serializer = read_command(pickleSer, 
infile)
+
         init_time = time.time()
 
         def process():

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a93a11/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index 79e4491..a76009e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.PythonRunner
+import org.apache.spark.api.python.{PythonFunction, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, 
GenericMutableRow, JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructField, StructType}
 
@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: 
Seq[Attribute], child:
 
   def children: Seq[SparkPlan] = child :: Nil
 
+  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], 
Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (fs, children) = collectFunctions(u)
+        (fs ++ Seq(udf.func), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be 
evaluated directly.
+        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+        (Seq(udf.func), udf.children)
+    }
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val inputRDD = child.execute().map(_.copy())
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: 
Seq[Attribute], child:
       // combine input with output from Python.
       val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
 
+      val (pyFuncs, children) = collectFunctions(udf)
+
       val pickle = new Pickler
-      val currentRow = newMutableProjection(udf.children, child.output)()
-      val fields = udf.children.map(_.dataType)
+      val currentRow = newMutableProjection(children, child.output)()
+      val fields = children.map(_.dataType)
       val schema = new StructType(fields.map(t => new StructField("", t, 
true)).toArray)
 
       // Input iterator to Python: input rows are grouped so we send them in 
batches to Python.
@@ -75,11 +89,8 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: 
Seq[Attribute], child:
       val context = TaskContext.get()
 
       // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(
-        udf.func,
-        bufferSize,
-        reuseWorker
-      ).compute(inputIterator, context.partitionId(), context)
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, 
true)
+        .compute(inputIterator, context.partitionId(), context)
 
       val unpickle = new Unpickler
       val row = new GenericMutableRow(1)

http://git-wip-us.apache.org/repos/asf/spark/blob/a7a93a11/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 6e76e95..c486ce1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -25,17 +26,40 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * Extracts PythonUDFs from operators, rewriting the query plan so that the 
UDF can be evaluated
  * alone in a batch.
  *
+ * Only extracts the PythonUDFs that could be evaluated in Python (the single 
child is PythonUDFs
+ * or all the children could be evaluated in JVM).
+ *
  * This has the limitation that the input to the Python UDF is not allowed 
include attributes from
  * multiple child operators.
  */
 private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+
+  private def hasPythonUDF(e: Expression): Boolean = {
+    e.find(_.isInstanceOf[PythonUDF]).isDefined
+  }
+
+  private def canEvaluateInPython(e: PythonUDF): Boolean = {
+    e.children match {
+      // single PythonUDF child could be chained and evaluated in Python
+      case Seq(u: PythonUDF) => canEvaluateInPython(u)
+      // Python UDF can't be evaluated directly in JVM
+      case children => !children.exists(hasPythonUDF)
+    }
+  }
+
+  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
+    expr.collect {
+      case udf: PythonUDF if canEvaluateInPython(udf) => udf
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // Skip EvaluatePython nodes.
     case plan: EvaluatePython => plan
 
     case plan: LogicalPlan if plan.resolved =>
       // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => 
udf })
+      val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
       if (udfs.isEmpty) {
         // If there aren't any, we are done.
         plan


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to