This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1fbb38ec5bd [SPARK-44249][SQL][PYTHON] Refactor PythonUDTFRunner to 
send its return type separately
1fbb38ec5bd is described below

commit 1fbb38ec5bd4ae8777053b1333fbe62a96f1e0f5
Author: Takuya UESHIN <[email protected]>
AuthorDate: Mon Jul 3 09:33:45 2023 +0900

    [SPARK-44249][SQL][PYTHON] Refactor PythonUDTFRunner to send its return 
type separately
    
    ### What changes were proposed in this pull request?
    
    Introduces a dedicated `PythonUDTFRunner` (built on a new `BasePythonUDFRunner`) that sends the Python UDTF's return type separately from its pickled command.
    
    ### Why are the changes needed?
    
    The return type of a Python UDTF doesn't need to be pickled into the Python
    "command", because `PythonUDTF` already knows its return type. The runner can
    send the return type separately, as a JSON string written after the command.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Updated the related tests; the existing tests also cover this change.
    
    Closes #41792 from ueshin/issues/SPARK-44249/return_type.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/udf.py                          |  8 ++-
 python/pyspark/sql/udtf.py                         |  2 +-
 python/pyspark/worker.py                           | 19 ++-----
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  2 +-
 .../execution/python/BatchEvalPythonUDTFExec.scala | 61 +++++++++++++++-------
 .../sql/execution/python/PythonUDFRunner.scala     | 48 +++++++++++------
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  | 10 ++--
 .../sql/execution/python/PythonUDTFSuite.scala     |  6 ---
 8 files changed, 94 insertions(+), 62 deletions(-)

diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index c6171ffece9..0d235660718 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -51,9 +51,13 @@ __all__ = ["UDFRegistration"]
 
 
 def _wrap_function(
-    sc: SparkContext, func: Callable[..., Any], returnType: "DataTypeOrString"
+    sc: SparkContext, func: Callable[..., Any], returnType: Optional[DataType] 
= None
 ) -> JavaObject:
-    command = (func, returnType)
+    command: Any
+    if returnType is None:
+        command = func
+    else:
+        command = (func, returnType)
     pickled_command, broadcast_vars, env, includes = 
_prepare_for_python_RDD(sc, command)
     assert sc._jvm is not None
     return sc._jvm.SimplePythonFunction(
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 95093970596..3bf7bc977c3 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -118,7 +118,7 @@ class UserDefinedTableFunction:
         spark = SparkSession._getActiveSessionOrCreate()
         sc = spark.sparkContext
 
-        wrapped_func = _wrap_function(sc, func, self.returnType)
+        wrapped_func = _wrap_function(sc, func)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         assert sc._jvm is not None
         judtf = 
sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 71a7ccd15aa..b24600b0c1b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -61,10 +61,10 @@ from pyspark.sql.pandas.serializers import (
     ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import StructType
+from pyspark.sql.types import StructType, _parse_datatype_json_string
 from pyspark.util import fail_on_stopiteration, try_simplify_traceback
 from pyspark import shuffle
-from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError
 
 pickleSer = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
@@ -461,20 +461,11 @@ def assign_cols_by_name(runner_conf):
 # ensure the UDTF is valid. This function also prepares a mapper function for 
applying
 # the UDTF logic to input rows.
 def read_udtf(pickleSer, infile, eval_type):
-    num_udtfs = read_int(infile)
-    if num_udtfs != 1:
-        raise PySparkValueError(f"Unexpected number of UDTFs. Expected 1 but 
got {num_udtfs}.")
-
-    # See `PythonUDFRunner.writeUDFs`.
+    # See `writeCommand` in `PythonUDTFRunner.newWriterThread`.
     num_arg = read_int(infile)
     arg_offsets = [read_int(infile) for _ in range(num_arg)]
-    num_chained_funcs = read_int(infile)
-    if num_chained_funcs != 1:
-        raise PySparkValueError(
-            f"Unexpected number of chained UDTFs. Expected 1 but got 
{num_chained_funcs}."
-        )
-
-    handler, return_type = read_command(pickleSer, infile)
+    handler = read_command(pickleSer, infile)
+    return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
             f"Invalid UDTF handler type. Expected a class (type 'type'), but "
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 6905bde9c33..829438ccec9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -159,7 +159,7 @@ abstract class UnevaluableGenerator extends Generator {
 case class PythonUDTF(
     name: String,
     func: PythonFunction,
-    override val elementSchema: StructType,
+    elementSchema: StructType,
     children: Seq[Expression],
     udfDeterministic: Boolean,
     resultId: ExprId = NamedExpression.newExprId)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index a7fdfb9d173..b233f3983a7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.File
+import java.io.{DataOutputStream, File}
+import java.net.Socket
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
@@ -31,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -69,21 +71,17 @@ case class BatchEvalPythonUDTFExec(
         queue.close()
       }
 
-      val inputs = Seq(udtf.children)
-
       // flatten all the arguments
       val allInputs = new ArrayBuffer[Expression]
       val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = inputs.map { input =>
-        input.map { e =>
-          if (allInputs.exists(_.semanticEquals(e))) {
-            allInputs.indexWhere(_.semanticEquals(e))
-          } else {
-            allInputs += e
-            dataTypes += e.dataType
-            allInputs.length - 1
-          }
-        }.toArray
+      val argOffsets = udtf.children.map { e =>
+        if (allInputs.exists(_.semanticEquals(e))) {
+          allInputs.indexWhere(_.semanticEquals(e))
+        } else {
+          allInputs += e
+          dataTypes += e.dataType
+          allInputs.length - 1
+        }
       }.toArray
       val projection = MutableProjection.create(allInputs.toSeq, child.output)
       projection.initialize(context.partitionId())
@@ -101,7 +99,7 @@ case class BatchEvalPythonUDTFExec(
         projection(inputRow)
       }
 
-      val outputRowIterator = evaluate(udtf, argOffsets, projectedRowIter, 
schema, context)
+      val outputRowIterator = evaluate(argOffsets, projectedRowIter, schema, 
context)
 
       val pruneChildForResult: InternalRow => InternalRow =
         if (child.outputSet == AttributeSet(requiredChildOutput)) {
@@ -136,8 +134,7 @@ case class BatchEvalPythonUDTFExec(
    * an iterator of internal rows for every input row.
    */
   private def evaluate(
-      udtf: PythonUDTF,
-      argOffsets: Array[Array[Int]],
+      argOffsets: Array[Int],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[Iterator[InternalRow]] = {
@@ -147,9 +144,8 @@ case class BatchEvalPythonUDTFExec(
     val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema)
 
     // Output iterator for results from Python.
-    val funcs = Seq(ChainedPythonFunctions(Seq(udtf.func)))
     val outputIterator =
-      new PythonUDFRunner(funcs, PythonEvalType.SQL_TABLE_UDF, argOffsets, 
pythonMetrics)
+      new PythonUDTFRunner(udtf, argOffsets, pythonMetrics)
         .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
@@ -173,3 +169,32 @@ case class BatchEvalPythonUDTFExec(
   override protected def withNewChildInternal(newChild: SparkPlan): 
BatchEvalPythonUDTFExec =
     copy(child = newChild)
 }
+
+class PythonUDTFRunner(
+    udtf: PythonUDTF,
+    argOffsets: Array[Int],
+    pythonMetrics: Map[String, SQLMetric])
+  extends BasePythonUDFRunner(
+    Seq(ChainedPythonFunctions(Seq(udtf.func))),
+    PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics) {
+
+  protected override def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[Array[Byte]],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread = {
+    new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, 
context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        dataOut.writeInt(argOffsets.length)
+        argOffsets.foreach { offset =>
+          dataOut.writeInt(offset)
+        }
+        dataOut.writeInt(udtf.func.command.length)
+        dataOut.write(udtf.func.command.toArray)
+        writeUTF(udtf.elementSchema.json, dataOut)
+      }
+    }
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 420ab284d53..6a952d9099e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf
 /**
  * A helper class to run Python UDFs in Spark.
  */
-class PythonUDFRunner(
+abstract class BasePythonUDFRunner(
     funcs: Seq[ChainedPythonFunctions],
     evalType: Int,
     argOffsets: Array[Array[Int]],
@@ -43,27 +43,22 @@ class PythonUDFRunner(
 
   override val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback
 
-  protected override def newWriterThread(
+  abstract class PythonUDFWriterThread(
       env: SparkEnv,
       worker: Socket,
       inputIterator: Iterator[Array[Byte]],
       partitionIndex: Int,
-      context: TaskContext): WriterThread = {
-    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
-
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
-      }
+      context: TaskContext)
+    extends WriterThread(env, worker, inputIterator, partitionIndex, context) {
 
-      protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
-        val startData = dataOut.size()
+    protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
+      val startData = dataOut.size()
 
-        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
-        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+      dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
 
-        val deltaData = dataOut.size() - startData
-        pythonMetrics("pythonDataSent") += deltaData
-      }
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
     }
   }
 
@@ -106,6 +101,29 @@ class PythonUDFRunner(
   }
 }
 
+class PythonUDFRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    pythonMetrics: Map[String, SQLMetric])
+  extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics) {
+
+  protected override def newWriterThread(
+      env: SparkEnv,
+      worker: Socket,
+      inputIterator: Iterator[Array[Byte]],
+      partitionIndex: Int,
+      context: TaskContext): WriterThread = {
+    new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, 
context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+      }
+
+    }
+  }
+}
+
 object PythonUDFRunner {
 
   def writeUDFs(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index c76e01a59d6..00962e77185 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -196,7 +196,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
     if (!shouldTestPythonUDFs) {
       throw new RuntimeException(s"Python executable [$pythonExec] and/or 
pyspark are unavailable.")
     }
-    var binaryPandasFunc: Array[Byte] = null
+    var binaryPythonUDTF: Array[Byte] = null
     withTempPath { codePath =>
       Files.write(codePath.toPath, 
pythonScript.getBytes(StandardCharsets.UTF_8))
       withTempPath { path =>
@@ -208,14 +208,14 @@ object IntegratedUDFTestUtils extends SQLHelper {
               s"f = open('$path', 'wb');" +
               s"exec(open('$codePath', 'r').read());" +
               "f.write(CloudPickleSerializer().dumps(" +
-              s"($funcName, returnType)))"),
+              s"$funcName))"),
           None,
           "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
-        binaryPandasFunc = Files.readAllBytes(path.toPath)
+        binaryPythonUDTF = Files.readAllBytes(path.toPath)
       }
     }
-    assert(binaryPandasFunc != null)
-    binaryPandasFunc
+    assert(binaryPythonUDTF != null)
+    binaryPythonUDTF
   }
 
   private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index fa4d80c331a..9bd0a13a3a1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -30,12 +30,6 @@ class PythonUDTFSuite extends QueryTest with 
SharedSparkSession {
 
   private val pythonScript: String =
     """
-      |from pyspark.sql.types import StructType, StructField, IntegerType
-      |returnType = StructType([
-      |  StructField("a", IntegerType()),
-      |  StructField("b", IntegerType()),
-      |  StructField("c", IntegerType()),
-      |])
       |class SimpleUDTF:
       |    def eval(self, a: int, b: int):
       |        yield a, b, a + b


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to