This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1fbb38ec5bd [SPARK-44249][SQL][PYTHON] Refactor PythonUDTFRunner to send its return type separately
1fbb38ec5bd is described below
commit 1fbb38ec5bd4ae8777053b1333fbe62a96f1e0f5
Author: Takuya UESHIN <[email protected]>
AuthorDate: Mon Jul 3 09:33:45 2023 +0900
[SPARK-44249][SQL][PYTHON] Refactor PythonUDTFRunner to send its return type separately
### What changes were proposed in this pull request?
Refactors `PythonUDTFRunner` to send its return type separately.
### Why are the changes needed?
The return type of a Python UDTF doesn't need to be included in the pickled Python "command", because `PythonUDTF` already knows the return type; the runner can send it separately.
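For illustration, a minimal worker-side sketch of the new stream layout (the helper name `read_udtf_sketch` and the standalone framing are illustrative only, not part of this change; the actual logic is in `read_udtf` in `python/pyspark/worker.py` in the diff below):

```python
# Hypothetical sketch: the pickled "command" now carries only the UDTF handler
# class, and the return type follows as a separate JSON-encoded schema string.
from pyspark.serializers import CPickleSerializer, UTF8Deserializer, read_int
from pyspark.sql.types import _parse_datatype_json_string

pickle_ser = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()

def read_udtf_sketch(infile):
    num_arg = read_int(infile)                      # number of argument offsets
    arg_offsets = [read_int(infile) for _ in range(num_arg)]
    handler = pickle_ser._read_with_length(infile)  # pickled handler class only
    # The return type is no longer bundled into the pickled command; it arrives
    # afterwards as a UTF-8 string holding the schema's JSON representation.
    return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
    return handler, arg_offsets, return_type
```

On the JVM side, `PythonUDTFRunner.PythonUDFWriterThread.writeCommand` (added in the diff below) writes in the matching order: the argument offsets, the pickled command bytes, and finally `udtf.elementSchema.json`.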
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Updated the related tests and existing tests.
Closes #41792 from ueshin/issues/SPARK-44249/return_type.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/udf.py | 8 ++-
python/pyspark/sql/udtf.py | 2 +-
python/pyspark/worker.py | 19 ++-----
.../spark/sql/catalyst/expressions/PythonUDF.scala | 2 +-
.../execution/python/BatchEvalPythonUDTFExec.scala | 61 +++++++++++++++-------
.../sql/execution/python/PythonUDFRunner.scala | 48 +++++++++++------
.../apache/spark/sql/IntegratedUDFTestUtils.scala | 10 ++--
.../sql/execution/python/PythonUDTFSuite.scala | 6 ---
8 files changed, 94 insertions(+), 62 deletions(-)
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index c6171ffece9..0d235660718 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -51,9 +51,13 @@ __all__ = ["UDFRegistration"]
def _wrap_function(
- sc: SparkContext, func: Callable[..., Any], returnType: "DataTypeOrString"
+ sc: SparkContext, func: Callable[..., Any], returnType: Optional[DataType] = None
) -> JavaObject:
- command = (func, returnType)
+ command: Any
+ if returnType is None:
+ command = func
+ else:
+ command = (func, returnType)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
assert sc._jvm is not None
return sc._jvm.SimplePythonFunction(
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 95093970596..3bf7bc977c3 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -118,7 +118,7 @@ class UserDefinedTableFunction:
spark = SparkSession._getActiveSessionOrCreate()
sc = spark.sparkContext
- wrapped_func = _wrap_function(sc, func, self.returnType)
+ wrapped_func = _wrap_function(sc, func)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
assert sc._jvm is not None
judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 71a7ccd15aa..b24600b0c1b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -61,10 +61,10 @@ from pyspark.sql.pandas.serializers import (
ApplyInPandasWithStateSerializer,
)
from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import StructType
+from pyspark.sql.types import StructType, _parse_datatype_json_string
from pyspark.util import fail_on_stopiteration, try_simplify_traceback
from pyspark import shuffle
-from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError
pickleSer = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -461,20 +461,11 @@ def assign_cols_by_name(runner_conf):
# ensure the UDTF is valid. This function also prepares a mapper function for applying
# the UDTF logic to input rows.
def read_udtf(pickleSer, infile, eval_type):
- num_udtfs = read_int(infile)
- if num_udtfs != 1:
- raise PySparkValueError(f"Unexpected number of UDTFs. Expected 1 but got {num_udtfs}.")
-
- # See `PythonUDFRunner.writeUDFs`.
+ # See `PythonUDTFRunner.PythonUDFWriterThread.writeCommand`
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for _ in range(num_arg)]
- num_chained_funcs = read_int(infile)
- if num_chained_funcs != 1:
- raise PySparkValueError(
- f"Unexpected number of chained UDTFs. Expected 1 but got
{num_chained_funcs}."
- )
-
- handler, return_type = read_command(pickleSer, infile)
+ handler = read_command(pickleSer, infile)
+ return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
if not isinstance(handler, type):
raise PySparkRuntimeError(
f"Invalid UDTF handler type. Expected a class (type 'type'), but "
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 6905bde9c33..829438ccec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -159,7 +159,7 @@ abstract class UnevaluableGenerator extends Generator {
case class PythonUDTF(
name: String,
func: PythonFunction,
- override val elementSchema: StructType,
+ elementSchema: StructType,
children: Seq[Expression],
udfDeterministic: Boolean,
resultId: ExprId = NamedExpression.newExprId)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index a7fdfb9d173..b233f3983a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution.python
-import java.io.File
+import java.io.{DataOutputStream, File}
+import java.net.Socket
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
@@ -31,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -69,21 +71,17 @@ case class BatchEvalPythonUDTFExec(
queue.close()
}
- val inputs = Seq(udtf.children)
-
// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
- val argOffsets = inputs.map { input =>
- input.map { e =>
- if (allInputs.exists(_.semanticEquals(e))) {
- allInputs.indexWhere(_.semanticEquals(e))
- } else {
- allInputs += e
- dataTypes += e.dataType
- allInputs.length - 1
- }
- }.toArray
+ val argOffsets = udtf.children.map { e =>
+ if (allInputs.exists(_.semanticEquals(e))) {
+ allInputs.indexWhere(_.semanticEquals(e))
+ } else {
+ allInputs += e
+ dataTypes += e.dataType
+ allInputs.length - 1
+ }
}.toArray
val projection = MutableProjection.create(allInputs.toSeq, child.output)
projection.initialize(context.partitionId())
@@ -101,7 +99,7 @@ case class BatchEvalPythonUDTFExec(
projection(inputRow)
}
- val outputRowIterator = evaluate(udtf, argOffsets, projectedRowIter, schema, context)
+ val outputRowIterator = evaluate(argOffsets, projectedRowIter, schema, context)
val pruneChildForResult: InternalRow => InternalRow =
if (child.outputSet == AttributeSet(requiredChildOutput)) {
@@ -136,8 +134,7 @@ case class BatchEvalPythonUDTFExec(
* an iterator of internal rows for every input row.
*/
private def evaluate(
- udtf: PythonUDTF,
- argOffsets: Array[Array[Int]],
+ argOffsets: Array[Int],
iter: Iterator[InternalRow],
schema: StructType,
context: TaskContext): Iterator[Iterator[InternalRow]] = {
@@ -147,9 +144,8 @@ case class BatchEvalPythonUDTFExec(
val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema)
// Output iterator for results from Python.
- val funcs = Seq(ChainedPythonFunctions(Seq(udtf.func)))
val outputIterator =
- new PythonUDFRunner(funcs, PythonEvalType.SQL_TABLE_UDF, argOffsets, pythonMetrics)
+ new PythonUDTFRunner(udtf, argOffsets, pythonMetrics)
.compute(inputIterator, context.partitionId(), context)
val unpickle = new Unpickler
@@ -173,3 +169,32 @@ case class BatchEvalPythonUDTFExec(
override protected def withNewChildInternal(newChild: SparkPlan): BatchEvalPythonUDTFExec =
copy(child = newChild)
}
+
+class PythonUDTFRunner(
+ udtf: PythonUDTF,
+ argOffsets: Array[Int],
+ pythonMetrics: Map[String, SQLMetric])
+ extends BasePythonUDFRunner(
+ Seq(ChainedPythonFunctions(Seq(udtf.func))),
+ PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics) {
+
+ protected override def newWriterThread(
+ env: SparkEnv,
+ worker: Socket,
+ inputIterator: Iterator[Array[Byte]],
+ partitionIndex: Int,
+ context: TaskContext): WriterThread = {
+ new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+ protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+ dataOut.writeInt(argOffsets.length)
+ argOffsets.foreach { offset =>
+ dataOut.writeInt(offset)
+ }
+ dataOut.writeInt(udtf.func.command.length)
+ dataOut.write(udtf.func.command.toArray)
+ writeUTF(udtf.elementSchema.json, dataOut)
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 420ab284d53..6a952d9099e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf
/**
* A helper class to run Python UDFs in Spark.
*/
-class PythonUDFRunner(
+abstract class BasePythonUDFRunner(
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
argOffsets: Array[Array[Int]],
@@ -43,27 +43,22 @@ class PythonUDFRunner(
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
- protected override def newWriterThread(
+ abstract class PythonUDFWriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[Array[Byte]],
partitionIndex: Int,
- context: TaskContext): WriterThread = {
- new WriterThread(env, worker, inputIterator, partitionIndex, context) {
-
- protected override def writeCommand(dataOut: DataOutputStream): Unit = {
- PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
- }
+ context: TaskContext)
+ extends WriterThread(env, worker, inputIterator, partitionIndex, context) {
- protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
- val startData = dataOut.size()
+ protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+ val startData = dataOut.size()
- PythonRDD.writeIteratorToStream(inputIterator, dataOut)
- dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+ dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
- val deltaData = dataOut.size() - startData
- pythonMetrics("pythonDataSent") += deltaData
- }
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
}
}
@@ -106,6 +101,29 @@ class PythonUDFRunner(
}
}
+class PythonUDFRunner(
+ funcs: Seq[ChainedPythonFunctions],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ pythonMetrics: Map[String, SQLMetric])
+ extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics) {
+
+ protected override def newWriterThread(
+ env: SparkEnv,
+ worker: Socket,
+ inputIterator: Iterator[Array[Byte]],
+ partitionIndex: Int,
+ context: TaskContext): WriterThread = {
+ new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+ protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+ PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+ }
+
+ }
+ }
+}
+
object PythonUDFRunner {
def writeUDFs(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index c76e01a59d6..00962e77185 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -196,7 +196,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
if (!shouldTestPythonUDFs) {
throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
}
- var binaryPandasFunc: Array[Byte] = null
+ var binaryPythonUDTF: Array[Byte] = null
withTempPath { codePath =>
Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
withTempPath { path =>
@@ -208,14 +208,14 @@ object IntegratedUDFTestUtils extends SQLHelper {
s"f = open('$path', 'wb');" +
s"exec(open('$codePath', 'r').read());" +
"f.write(CloudPickleSerializer().dumps(" +
- s"($funcName, returnType)))"),
+ s"$funcName))"),
None,
"PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
- binaryPandasFunc = Files.readAllBytes(path.toPath)
+ binaryPythonUDTF = Files.readAllBytes(path.toPath)
}
}
- assert(binaryPandasFunc != null)
- binaryPandasFunc
+ assert(binaryPythonUDTF != null)
+ binaryPythonUDTF
}
private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index fa4d80c331a..9bd0a13a3a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -30,12 +30,6 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
private val pythonScript: String =
"""
- |from pyspark.sql.types import StructType, StructField, IntegerType
- |returnType = StructType([
- | StructField("a", IntegerType()),
- | StructField("b", IntegerType()),
- | StructField("c", IntegerType()),
- |])
|class SimpleUDTF:
| def eval(self, a: int, b: int):
| yield a, b, a + b
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]