This is an automated email from the ASF dual-hosted git repository.

allisonwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d2b49662fa24  [SPARK-52981][PYTHON] Add table argument support for Arrow Python UDTFs
d2b49662fa24 is described below

commit d2b49662fa243bd9407f1bfcfa1f14ace4b60a76
Author: Allison Wang <allison.w...@databricks.com>
AuthorDate: Tue Aug 12 14:25:12 2025 -0700

    [SPARK-52981][PYTHON] Add table argument support for Arrow Python UDTFs

    ### What changes were proposed in this pull request?

    This PR adds support for table arguments in Arrow Python UDTFs:

    ```
    @arrow_udtf(returnType="..")
    class TableArgUDTF:
        def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
            ...
    ```

    ### Why are the changes needed?

    To support table arguments in Arrow Python UDTFs.

    ### Does this PR introduce _any_ user-facing change?

    Yes. Arrow Python UDTFs can now accept table arguments.

    ### How was this patch tested?

    New unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #51860 from allisonwang-db/spark-52981-udtf-table-args.

    Authored-by: Allison Wang <allison.w...@databricks.com>
    Signed-off-by: Allison Wang <allison.w...@databricks.com>
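For orientation before the diff, here is a condensed usage sketch of the new capability, adapted from the tests added in this PR. The `arrow_udtf` import path and the active `spark` session are assumptions, not shown in the diff below:

```python
from typing import Iterator

import pyarrow as pa
import pyarrow.compute as pc

from pyspark.sql.functions import arrow_udtf  # assumed import path


@arrow_udtf(returnType="filtered_id bigint")
class FilterUDTF:
    def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
        # The TABLE(...) argument arrives as a PyArrow RecordBatch.
        table = pa.table(table_data)
        mask = pc.greater(table.column("id"), pa.scalar(5))
        yield pa.table({"filtered_id": table.filter(mask).column("id")})


# Register the UDTF and invoke it with the SQL TABLE() syntax.
spark.udtf.register("filter_udtf", FilterUDTF)
spark.sql("SELECT * FROM filter_udtf(TABLE(SELECT id FROM range(0, 8)))").show()
```

Rows 6 and 7 survive the filter, matching the expected output asserted in the tests below.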
---
 python/pyspark/sql/pandas/serializers.py           | 28 +++++++-
 python/pyspark/sql/tests/arrow/test_arrow_udtf.py  | 79 ++++++++++++++++++++++
 python/pyspark/worker.py                           |  7 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  2 +-
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  7 +-
 .../execution/python/ArrowPythonUDTFRunner.scala   |  8 +++
 .../execution/python/BatchEvalPythonUDTFExec.scala |  2 +-
 .../sql/execution/python/EvalPythonExec.scala      |  3 +-
 .../sql/execution/python/EvalPythonUDTFExec.scala  |  8 ++-
 .../sql/execution/python/PythonUDFRunner.scala     |  2 +-
 .../python/UserDefinedPythonFunction.scala         | 21 +++---
 11 files changed, 145 insertions(+), 22 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 769f5e043a77..73546d2320bd 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -201,8 +201,31 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
     Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays.
     """
 
-    # TODO(SPARK-52981): support table arguments
-    ...
+    def __init__(self, table_arg_offsets=None):
+        super().__init__()
+        self.table_arg_offsets = table_arg_offsets if table_arg_offsets else []
+
+    def load_stream(self, stream):
+        """
+        Flatten the struct into Arrow's record batches.
+        """
+        import pyarrow as pa
+
+        batches = super().load_stream(stream)
+        for batch in batches:
+            result_batches = []
+            for i in range(batch.num_columns):
+                if i in self.table_arg_offsets:
+                    struct = batch.column(i)
+                    # Flatten the struct and create a RecordBatch from it
+                    flattened_batch = pa.RecordBatch.from_arrays(
+                        struct.flatten(), schema=pa.schema(struct.type)
+                    )
+                    result_batches.append(flattened_batch)
+                else:
+                    # Keep the column as it is for non-table columns
+                    result_batches.append(batch.column(i))
+            yield result_batches
 
 
 class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
@@ -1584,6 +1607,7 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
     def generate_data_batches(batches):
         """
         Deserialize ArrowRecordBatches and return a generator of pandas.Series list.
+
         The deserialization logic assumes that Arrow RecordBatches contain the data with the
         ordering that data chunks for same grouping key will appear sequentially.
         See `TransformWithStateInPandasPythonInitialStateRunner` for arrow batch schema sent
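The heart of the serializer change is the struct flattening in `load_stream`: a table argument arrives as one struct column per batch, and the serializer rebuilds a standalone `pa.RecordBatch` from its children. A minimal standalone PyArrow sketch of just that step, with made-up data:

```python
import pyarrow as pa

# A table argument reaches the serializer as a single struct column whose
# fields are the columns of the input relation (values are illustrative).
struct = pa.array([{"id": 1, "name": "a"}, {"id": 2, "name": "b"}])

# The same flattening load_stream applies to columns listed in
# table_arg_offsets: split the struct into child arrays, rebuild a batch.
flattened = pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))

assert flattened.schema.names == ["id", "name"]
assert flattened.num_rows == 2
```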
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index da15d48fceda..c7274ebebd83 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -385,8 +385,87 @@ class ArrowUDTFTests(ReusedSQLTestCase):
 
         self.assertIn("INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE", str(cm.exception))
 
+    def test_arrow_udtf_with_table_argument_basic(self):
+        @arrow_udtf(returnType="filtered_id bigint")  # Use bigint to match int64
+        class TableArgUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch, got {type(table_data)}"
+
+                # Convert record batch to table to work with it more easily
+                table = pa.table(table_data)
+
+                # Filter rows where id > 5
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(5))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type (int64)
+                    )
+                    yield result_table
+
+        # TODO(SPARK-53251): Enable DataFrame API testing with asTable()
+        # # Test with DataFrame API using asTable()
+        # input_df = self.spark.range(8)
+        # result_df = TableArgUDTF(input_df.asTable())
+        expected_df = self.spark.createDataFrame([(6,), (7,)], "filtered_id bigint")
+        # assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage with TABLE() syntax
+        self.spark.udtf.register("test_table_arg_udtf", TableArgUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_table_arg_udtf(TABLE(SELECT id FROM range(0, 8)))"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_table_argument_and_scalar(self):
+        @arrow_udtf(returnType="filtered_id bigint")  # Use bigint to match int64
+        class MixedArgsUDTF:
+            def eval(
+                self, table_data: "pa.RecordBatch", threshold: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    threshold, pa.Array
+                ), f"Expected pa.Array for threshold, got {type(threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+
+                threshold_val = threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(threshold_val))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # # Test with DataFrame API
+        # TODO(SPARK-53251): Enable DataFrame API testing with asTable()
+        # input_df = self.spark.range(8)
+        # result_df = MixedArgsUDTF(input_df.asTable(), lit(5))
+        expected_df = self.spark.createDataFrame([(6,), (7,)], "filtered_id bigint")
+        # assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_mixed_args_udtf", MixedArgsUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_mixed_args_udtf(TABLE(SELECT id FROM range(0, 8)), 5)"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
 
 if __name__ == "__main__":
+    from pyspark.sql.tests.arrow.test_arrow_udtf import *  # noqa: F401
+
     try:
         import xmlrunner
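The next files agree on a small wire format: the JVM writes the number of table-argument offsets and then each offset as big-endian ints (Java's `DataOutputStream.writeInt`), and the worker reads them back with `read_int`. A sketch of that framing using pyspark's own serializer helpers, with illustrative offset values:

```python
import io

from pyspark.serializers import read_int, write_int

# JVM side (ArrowPythonUDTFRunner, further below): count, then each offset.
buf = io.BytesIO()
offsets_written = [0, 3]  # illustrative positions among the UDTF arguments
write_int(len(offsets_written), buf)
for offset in offsets_written:
    write_int(offset, buf)

# Python worker side (read_udtf, next file below): the same framing, read back.
buf.seek(0)
num_table_arg_offsets = read_int(buf)
table_arg_offsets = [read_int(buf) for _ in range(num_table_arg_offsets)]
assert table_arg_offsets == offsets_written
```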
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 207c2a999571..abdc956188e4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1340,8 +1340,11 @@ def read_udtf(pickleSer, infile, eval_type):
             v = utf8_deserializer.loads(infile)
             runner_conf[k] = v
         prefers_large_var_types = use_large_var_types(runner_conf)
-        # Use PyArrow-native serializer for Arrow UDTFs
-        ser = ArrowStreamArrowUDTFSerializer()
+        # Read the table argument offsets
+        num_table_arg_offsets = read_int(infile)
+        table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
+        # Use PyArrow-native serializer for Arrow UDTFs with potential UDT support
+        ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
     else:
         # Each row is a group so do not batch but send one by one.
         ser = BatchedSerializer(CPickleSerializer(), 1)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e49e6aa7f044..7d243c227cdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2289,7 +2289,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             }
             PythonUDTF(
               u.name, u.func, analyzeResult.schema, Some(analyzeResult.pickledAnalyzeResult),
-              newChildren, u.evalType, u.udfDeterministic, u.resultId)
+              newChildren, u.evalType, u.udfDeterministic, u.resultId, None, u.tableArguments)
           }
         }
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index bacb350b2895..e4d0f9642773 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -176,6 +176,7 @@ abstract class UnevaluableGenerator extends Generator {
 * @param pythonUDTFPartitionColumnIndexes holds the zero-based indexes of the projected results of
 *                                         all PARTITION BY expressions within the TABLE argument of
 *                                         the Python UDTF call, if applicable
+ * @param tableArguments holds whether an input argument is a table argument
 */
case class PythonUDTF(
    name: String,
@@ -186,7 +187,8 @@ case class PythonUDTF(
    evalType: Int,
    udfDeterministic: Boolean,
    resultId: ExprId = NamedExpression.newExprId,
-    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
+    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None,
+    tableArguments: Option[Seq[Boolean]] = None)
  extends UnevaluableGenerator with PythonFuncExpression {

  override lazy val canonicalized: Expression = {
@@ -215,7 +217,8 @@ case class UnresolvedPolymorphicPythonUDTF(
    evalType: Int,
    udfDeterministic: Boolean,
    resolveElementMetadata: (PythonFunction, Seq[Expression]) => PythonUDTFAnalyzeResult,
-    resultId: ExprId = NamedExpression.newExprId)
+    resultId: ExprId = NamedExpression.newExprId,
+    tableArguments: Option[Seq[Boolean]] = None)
  extends UnevaluableGenerator with PythonFuncExpression {

  override lazy val resolved = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index f440edc83f6c..c081787b5209 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -54,6 +54,14 @@ class ArrowPythonUDTFRunner(
      if (evalType == PythonEvalType.SQL_ARROW_TABLE_UDF) {
        PythonWorkerUtils.writeUTF(schema.json, dataOut)
      }
+      // Write the table argument offsets for Arrow UDTFs.
+      else if (evalType == PythonEvalType.SQL_ARROW_UDTF) {
+        val tableArgOffsets = argMetas.collect {
+          case ArgumentMetadata(offset, _, isTableArg) if isTableArg => offset
+        }
+        dataOut.writeInt(tableArgOffsets.length)
+        tableArgOffsets.foreach(dataOut.writeInt(_))
+      }
      PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
    }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index c0dcb7781742..a1358c9cd774 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -120,7 +120,7 @@ object PythonUDTFRunner {
      // Write the argument types of the UDTF.
      dataOut.writeInt(argMetas.length)
      argMetas.foreach {
-        case ArgumentMetadata(offset, name) =>
+        case ArgumentMetadata(offset, name, _) =>
          dataOut.writeInt(offset)
          name match {
            case Some(name) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index af6769cfbb9d..0c366b1280b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -28,8 +28,9 @@ object EvalPythonExec {
   *
   * @param offset the offset of the argument
   * @param name the name of the argument if it's a `NamedArgumentExpression`
+   * @param isTableArg whether this argument is a table argument
   */
-  case class ArgumentMetadata(offset: Int, name: Option[String])
+  case class ArgumentMetadata(offset: Int, name: Option[String], isTableArg: Boolean = false)
 }

 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
index 41a99693443e..3cb9431fed6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
@@ -68,7 +68,9 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
    // flatten all the arguments
    val allInputs = new ArrayBuffer[Expression]
    val dataTypes = new ArrayBuffer[DataType]
-    val argMetas = udtf.children.map { e =>
+    val argMetas = udtf.children.zip(
+      udtf.tableArguments.getOrElse(Seq.fill(udtf.children.length)(false))
+    ).map { case (e: Expression, isTableArg: Boolean) =>
      val (key, value) = e match {
        case NamedArgumentExpression(key, value) =>
          (Some(key), value)
@@ -76,11 +78,11 @@
          (None, e)
      }
      if (allInputs.exists(_.semanticEquals(value))) {
-        ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key)
+        ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key, isTableArg)
      } else {
        allInputs += value
        dataTypes += value.dataType
-        ArgumentMetadata(allInputs.length - 1, key)
+        ArgumentMetadata(allInputs.length - 1, key, isTableArg)
      }
    }.toArray
    val projection = MutableProjection.create(allInputs.toSeq, child.output)
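One subtlety in the hunk above: offsets index the de-duplicated projected inputs rather than the argument list, so two semantically equal arguments share an offset while the new table-argument flag stays per argument. A toy Python model of that bookkeeping, with strings standing in for expressions and plain equality standing in for semanticEquals:

```python
def assign_arg_metas(args):
    """args: list of (expression, is_table_arg) pairs."""
    all_inputs = []  # de-duplicated projected inputs
    arg_metas = []   # (offset, is_table_arg), one entry per argument
    for value, is_table_arg in args:
        if value in all_inputs:
            # Reuse the offset of the semantically equal input.
            arg_metas.append((all_inputs.index(value), is_table_arg))
        else:
            all_inputs.append(value)
            arg_metas.append((len(all_inputs) - 1, is_table_arg))
    return all_inputs, arg_metas

# "x" appears twice, so both of those arguments share offset 0,
# yet only the first is flagged as a table argument.
inputs, metas = assign_arg_metas([("x", True), ("y", False), ("x", False)])
assert inputs == ["x", "y"]
assert metas == [(0, True), (1, False), (0, False)]
```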
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 3f30519e9521..8ff7e57d9421 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -193,7 +193,7 @@ object PythonUDFRunner {
    funcs.zip(argMetas).foreach { case ((chained, resultId), metas) =>
      dataOut.writeInt(metas.length)
      metas.foreach {
-        case ArgumentMetadata(offset, name) =>
+        case ArgumentMetadata(offset, name, _) =>
          dataOut.writeInt(offset)
          name match {
            case Some(name) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 32d8e8a21336..8f04af8295da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -123,6 +123,14 @@ case class UserDefinedPythonTableFunction(
     */
    NamedParametersSupport.splitAndCheckNamedArguments(exprs, name)

+    // Check which argument is a table argument here since it will be replaced with
+    // `UnresolvedAttribute` to construct lateral join.
+    val tableArgs = exprs.map {
+      case _: FunctionTableSubqueryArgumentExpression => true
+      case NamedArgumentExpression(_, _: FunctionTableSubqueryArgumentExpression) => true
+      case _ => false
+    }
+
    val udtf = returnType match {
      case Some(rt) =>
        PythonUDTF(
@@ -132,15 +140,9 @@ case class UserDefinedPythonTableFunction(
          pickledAnalyzeResult = None,
          children = exprs,
          evalType = pythonEvalType,
-          udfDeterministic = udfDeterministic)
+          udfDeterministic = udfDeterministic,
+          tableArguments = Some(tableArgs))
      case _ =>
-        // Check which argument is a table argument here since it will be replaced with
-        // `UnresolvedAttribute` to construct lateral join.
-        val tableArgs = exprs.map {
-          case _: FunctionTableSubqueryArgumentExpression => true
-          case NamedArgumentExpression(_, _: FunctionTableSubqueryArgumentExpression) => true
-          case _ => false
-        }
        val runAnalyzeInPython = (func: PythonFunction, exprs: Seq[Expression]) => {
          val runner =
            new UserDefinedPythonTableFunctionAnalyzeRunner(name, func, exprs, tableArgs, parser)
@@ -152,7 +154,8 @@ case class UserDefinedPythonTableFunction(
          children = exprs,
          evalType = pythonEvalType,
          udfDeterministic = udfDeterministic,
-          resolveElementMetadata = runAnalyzeInPython)
+          resolveElementMetadata = runAnalyzeInPython,
+          tableArguments = Some(tableArgs))
    }
    Generate(
      udtf,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org