This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d2b49662fa24 [SPARK-52981][PYTHON] Add table argument support for Arrow Python UDTFs
d2b49662fa24 is described below

commit d2b49662fa243bd9407f1bfcfa1f14ace4b60a76
Author: Allison Wang <allison.w...@databricks.com>
AuthorDate: Tue Aug 12 14:25:12 2025 -0700

    [SPARK-52981][PYTHON] Add table argument support for Arrow Python UDTFs
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for table arguments in Arrow Python UDTFs.
    
    ```
    @arrow_udtf(returnType="..")
    class TableArgUDTF:
        def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
            ...
    ```
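
    As a minimal usage sketch (mirroring the new tests in this PR; the registered name `filter_udtf` is illustrative, and it assumes `TableArgUDTF` above is defined with a concrete `returnType`), a registered Arrow UDTF can receive a table argument through the SQL `TABLE(...)` syntax:

    ```python
    # Register the Arrow UDTF under an illustrative name, then pass a table argument.
    spark.udtf.register("filter_udtf", TableArgUDTF)
    spark.sql(
        "SELECT * FROM filter_udtf(TABLE(SELECT id FROM range(0, 8)))"
    ).show()
    ```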
    
    ### Why are the changes needed?
    
    To support table arguments in Arrow Python UDTFs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Arrow Python UDTFs can now accept table arguments, e.g. through the SQL `TABLE(...)` syntax.
    
    ### How was this patch tested?
    
    New unit tests in `python/pyspark/sql/tests/arrow/test_arrow_udtf.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51860 from allisonwang-db/spark-52981-udtf-table-args.
    
    Authored-by: Allison Wang <allison.w...@databricks.com>
    Signed-off-by: Allison Wang <allison.w...@databricks.com>
---
 python/pyspark/sql/pandas/serializers.py           | 28 +++++++-
 python/pyspark/sql/tests/arrow/test_arrow_udtf.py  | 79 ++++++++++++++++++++++
 python/pyspark/worker.py                           |  7 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  2 +-
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  7 +-
 .../execution/python/ArrowPythonUDTFRunner.scala   |  8 +++
 .../execution/python/BatchEvalPythonUDTFExec.scala |  2 +-
 .../sql/execution/python/EvalPythonExec.scala      |  3 +-
 .../sql/execution/python/EvalPythonUDTFExec.scala  |  8 ++-
 .../sql/execution/python/PythonUDFRunner.scala     |  2 +-
 .../python/UserDefinedPythonFunction.scala         | 21 +++---
 11 files changed, 145 insertions(+), 22 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 769f5e043a77..73546d2320bd 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -201,8 +201,31 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
     Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays.
     """
 
-    # TODO(SPARK-52981): support table arguments
-    ...
+    def __init__(self, table_arg_offsets=None):
+        super().__init__()
+        self.table_arg_offsets = table_arg_offsets if table_arg_offsets else []
+
+    def load_stream(self, stream):
+        """
+        Flatten the struct into Arrow's record batches.
+        """
+        import pyarrow as pa
+
+        batches = super().load_stream(stream)
+        for batch in batches:
+            result_batches = []
+            for i in range(batch.num_columns):
+                if i in self.table_arg_offsets:
+                    struct = batch.column(i)
+                    # Flatten the struct and create a RecordBatch from it
+                    flattened_batch = pa.RecordBatch.from_arrays(
+                        struct.flatten(), schema=pa.schema(struct.type)
+                    )
+                    result_batches.append(flattened_batch)
+                else:
+                    # Keep the column as it is for non-table columns
+                    result_batches.append(batch.column(i))
+            yield result_batches
 
 
 class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
@@ -1584,6 +1607,7 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
         def generate_data_batches(batches):
             """
             Deserialize ArrowRecordBatches and return a generator of pandas.Series list.
+
             The deserialization logic assumes that Arrow RecordBatches contain the data with the
             ordering that data chunks for same grouping key will appear sequentially.
             See `TransformWithStateInPandasPythonInitialStateRunner` for arrow batch schema sent
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index da15d48fceda..c7274ebebd83 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -385,8 +385,87 @@ class ArrowUDTFTests(ReusedSQLTestCase):
 
         self.assertIn("INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE", str(cm.exception))
 
+    def test_arrow_udtf_with_table_argument_basic(self):
+        @arrow_udtf(returnType="filtered_id bigint")  # Use bigint to match int64
+        class TableArgUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch, got {type(table_data)}"
+
+                # Convert record batch to table to work with it more easily
+                table = pa.table(table_data)
+
+                # Filter rows where id > 5
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(5))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep 
original type (int64)
+                    )
+                    yield result_table
+
+        # TODO(SPARK-53251): Enable DataFrame API testing with asTable()
+        # # Test with DataFrame API using asTable()
+        # input_df = self.spark.range(8)
+        # result_df = TableArgUDTF(input_df.asTable())
+        expected_df = self.spark.createDataFrame([(6,), (7,)], "filtered_id bigint")
+        # assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage with TABLE() syntax
+        self.spark.udtf.register("test_table_arg_udtf", TableArgUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_table_arg_udtf(TABLE(SELECT id FROM range(0, 
8)))"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_table_argument_and_scalar(self):
+        @arrow_udtf(returnType="filtered_id bigint")  # Use bigint to match int64
+        class MixedArgsUDTF:
+            def eval(
+                self, table_data: "pa.RecordBatch", threshold: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    threshold, pa.Array
+                ), f"Expected pa.Array for threshold, got {type(threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got 
{type(table_data)}"
+
+                threshold_val = threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(threshold_val))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep 
original type
+                    )
+                    yield result_table
+
+        # # Test with DataFrame API
+        # TODO(SPARK-53251): Enable DataFrame API testing with asTable()
+        # input_df = self.spark.range(8)
+        # result_df = MixedArgsUDTF(input_df.asTable(), lit(5))
+        expected_df = self.spark.createDataFrame([(6,), (7,)], "filtered_id bigint")
+        # assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_mixed_args_udtf", MixedArgsUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_mixed_args_udtf(TABLE(SELECT id FROM range(0, 
8)), 5)"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
 
 if __name__ == "__main__":
+    from pyspark.sql.tests.arrow.test_arrow_udtf import *  # noqa: F401
+
     try:
         import xmlrunner
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 207c2a999571..abdc956188e4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1340,8 +1340,11 @@ def read_udtf(pickleSer, infile, eval_type):
             v = utf8_deserializer.loads(infile)
             runner_conf[k] = v
         prefers_large_var_types = use_large_var_types(runner_conf)
-        # Use PyArrow-native serializer for Arrow UDTFs
-        ser = ArrowStreamArrowUDTFSerializer()
+        # Read the table argument offsets
+        num_table_arg_offsets = read_int(infile)
+        table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
+        # Use PyArrow-native serializer for Arrow UDTFs with potential UDT support
+        ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
     else:
         # Each row is a group so do not batch but send one by one.
         ser = BatchedSerializer(CPickleSerializer(), 1)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e49e6aa7f044..7d243c227cdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2289,7 +2289,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             }
             PythonUDTF(
               u.name, u.func, analyzeResult.schema, Some(analyzeResult.pickledAnalyzeResult),
-              newChildren, u.evalType, u.udfDeterministic, u.resultId)
+              newChildren, u.evalType, u.udfDeterministic, u.resultId, None, u.tableArguments)
           }
         }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index bacb350b2895..e4d0f9642773 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -176,6 +176,7 @@ abstract class UnevaluableGenerator extends Generator {
  * @param pythonUDTFPartitionColumnIndexes holds the zero-based indexes of the projected results of
  *                                         all PARTITION BY expressions within the TABLE argument of
  *                                         the Python UDTF call, if applicable
+ * @param tableArguments holds whether an input argument is a table argument
  */
 case class PythonUDTF(
     name: String,
@@ -186,7 +187,8 @@ case class PythonUDTF(
     evalType: Int,
     udfDeterministic: Boolean,
     resultId: ExprId = NamedExpression.newExprId,
-    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
+    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None,
+    tableArguments: Option[Seq[Boolean]] = None)
   extends UnevaluableGenerator with PythonFuncExpression {
 
   override lazy val canonicalized: Expression = {
@@ -215,7 +217,8 @@ case class UnresolvedPolymorphicPythonUDTF(
     evalType: Int,
     udfDeterministic: Boolean,
     resolveElementMetadata: (PythonFunction, Seq[Expression]) => PythonUDTFAnalyzeResult,
-    resultId: ExprId = NamedExpression.newExprId)
+    resultId: ExprId = NamedExpression.newExprId,
+    tableArguments: Option[Seq[Boolean]] = None)
   extends UnevaluableGenerator with PythonFuncExpression {
 
   override lazy val resolved = false
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index f440edc83f6c..c081787b5209 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -54,6 +54,14 @@ class ArrowPythonUDTFRunner(
     if (evalType == PythonEvalType.SQL_ARROW_TABLE_UDF) {
       PythonWorkerUtils.writeUTF(schema.json, dataOut)
     }
+    // Write the table argument offsets for Arrow UDTFs.
+    else if (evalType == PythonEvalType.SQL_ARROW_UDTF) {
+      val tableArgOffsets = argMetas.collect {
+        case ArgumentMetadata(offset, _, isTableArg) if isTableArg => offset
+      }
+      dataOut.writeInt(tableArgOffsets.length)
+      tableArgOffsets.foreach(dataOut.writeInt(_))
+    }
     PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index c0dcb7781742..a1358c9cd774 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -120,7 +120,7 @@ object PythonUDTFRunner {
     // Write the argument types of the UDTF.
     dataOut.writeInt(argMetas.length)
     argMetas.foreach {
-      case ArgumentMetadata(offset, name) =>
+      case ArgumentMetadata(offset, name, _) =>
         dataOut.writeInt(offset)
         name match {
           case Some(name) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index af6769cfbb9d..0c366b1280b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -28,8 +28,9 @@ object EvalPythonExec {
    *
    * @param offset the offset of the argument
    * @param name the name of the argument if it's a `NamedArgumentExpression`
+   * @param isTableArg whether this argument is a table argument
    */
-  case class ArgumentMetadata(offset: Int, name: Option[String])
-  case class ArgumentMetadata(offset: Int, name: Option[String])
+  case class ArgumentMetadata(offset: Int, name: Option[String], isTableArg: Boolean = false)
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
index 41a99693443e..3cb9431fed6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
@@ -68,7 +68,9 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
       // flatten all the arguments
       val allInputs = new ArrayBuffer[Expression]
       val dataTypes = new ArrayBuffer[DataType]
-      val argMetas = udtf.children.map { e =>
+      val argMetas = udtf.children.zip(
+        udtf.tableArguments.getOrElse(Seq.fill(udtf.children.length)(false))
+      ).map { case (e: Expression, isTableArg: Boolean) =>
         val (key, value) = e match {
           case NamedArgumentExpression(key, value) =>
             (Some(key), value)
@@ -76,11 +78,11 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
             (None, e)
         }
         if (allInputs.exists(_.semanticEquals(value))) {
-          ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key)
+          ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key, isTableArg)
         } else {
           allInputs += value
           dataTypes += value.dataType
-          ArgumentMetadata(allInputs.length - 1, key)
+          ArgumentMetadata(allInputs.length - 1, key, isTableArg)
         }
       }.toArray
       val projection = MutableProjection.create(allInputs.toSeq, child.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 3f30519e9521..8ff7e57d9421 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -193,7 +193,7 @@ object PythonUDFRunner {
     funcs.zip(argMetas).foreach { case ((chained, resultId), metas) =>
       dataOut.writeInt(metas.length)
       metas.foreach {
-        case ArgumentMetadata(offset, name) =>
+        case ArgumentMetadata(offset, name, _) =>
           dataOut.writeInt(offset)
           name match {
             case Some(name) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 32d8e8a21336..8f04af8295da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -123,6 +123,14 @@ case class UserDefinedPythonTableFunction(
      */
     NamedParametersSupport.splitAndCheckNamedArguments(exprs, name)
 
+    // Check which argument is a table argument here since it will be replaced with
+    // `UnresolvedAttribute` to construct lateral join.
+    val tableArgs = exprs.map {
+      case _: FunctionTableSubqueryArgumentExpression => true
+      case NamedArgumentExpression(_, _: FunctionTableSubqueryArgumentExpression) => true
+      case _ => false
+    }
+
     val udtf = returnType match {
       case Some(rt) =>
         PythonUDTF(
@@ -132,15 +140,9 @@ case class UserDefinedPythonTableFunction(
           pickledAnalyzeResult = None,
           children = exprs,
           evalType = pythonEvalType,
-          udfDeterministic = udfDeterministic)
+          udfDeterministic = udfDeterministic,
+          tableArguments = Some(tableArgs))
       case _ =>
-        // Check which argument is a table argument here since it will be replaced with
-        // `UnresolvedAttribute` to construct lateral join.
-        val tableArgs = exprs.map {
-          case _: FunctionTableSubqueryArgumentExpression => true
-          case NamedArgumentExpression(_, _: FunctionTableSubqueryArgumentExpression) => true
-          case _ => false
-        }
         val runAnalyzeInPython = (func: PythonFunction, exprs: Seq[Expression]) => {
           val runner =
             new UserDefinedPythonTableFunctionAnalyzeRunner(name, func, exprs, tableArgs, parser)
@@ -152,7 +154,8 @@ case class UserDefinedPythonTableFunction(
           children = exprs,
           evalType = pythonEvalType,
           udfDeterministic = udfDeterministic,
-          resolveElementMetadata = runAnalyzeInPython)
+          resolveElementMetadata = runAnalyzeInPython,
+          tableArguments = Some(tableArgs))
     }
     Generate(
       udtf,

