This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 25c57d87382c [SPARK-54589][PYTHON] Consolidate ArrowStreamAggPandasIterUDFSerializer into ArrowStreamAggPandasUDFSerializer
25c57d87382c is described below
commit 25c57d87382ca90850541d4d20447bfd842ec85a
Author: Yicong-Huang <[email protected]>
AuthorDate: Tue Dec 16 12:32:11 2025 +0800
[SPARK-54589][PYTHON] Consolidate ArrowStreamAggPandasIterUDFSerializer into ArrowStreamAggPandasUDFSerializer
### What changes were proposed in this pull request?
This PR consolidates `ArrowStreamAggPandasIterUDFSerializer` into `ArrowStreamAggPandasUDFSerializer` for the grouped-agg pandas UDF eval types.
Changes:
1. **Removed `ArrowStreamAggPandasIterUDFSerializer`** - The class was nearly identical to `ArrowStreamAggPandasUDFSerializer`
2. **Unified serializer** - `ArrowStreamAggPandasUDFSerializer` now serves `SQL_GROUPED_AGG_PANDAS_UDF`, `SQL_GROUPED_AGG_PANDAS_ITER_UDF`, and `SQL_WINDOW_AGG_PANDAS_UDF`
3. **Added mapper for non-iter UDFs** - A new mapper in `worker.py` handles batch concatenation for `SQL_GROUPED_AGG_PANDAS_UDF` and `SQL_WINDOW_AGG_PANDAS_UDF` (see the sketch after this list)
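To make the new data flow concrete, here is a minimal sketch (illustrative only, not code from this patch) of the shapes involved: the unified serializer yields one tuple of `pandas.Series` per Arrow batch, and the new mapper concatenates those tuples into one `pandas.Series` per column for the non-iter eval types.

```python
import pandas as pd

# Shape the unified load_stream yields: one tuple of pandas.Series per Arrow batch
# (two columns, two batches in this toy example).
batches = [
    (pd.Series([1.0]), pd.Series([10.0])),
    (pd.Series([2.0]), pd.Series([20.0])),
]

# What the new worker-side mapper does for the non-iter eval types:
# concatenate the per-batch Series into a single Series per column.
concatenated = [pd.concat(cols, ignore_index=True) for cols in zip(*batches)]
print(concatenated[0].tolist())  # [1.0, 2.0]
print(concatenated[1].tolist())  # [10.0, 20.0]
```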
### Why are the changes needed?
Similar to SPARK-54316, the two serializer classes had nearly identical implementations:
- Identical `__init__` methods
- Same base class (`ArrowStreamPandasUDFSerializer`)
- Only `load_stream` differed slightly in output format
### Does this PR introduce _any_ user-facing change?
No. It's an internal refactor.
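For reference, a typical grouped-agg pandas UDF (illustrative usage only, not part of this patch) behaves the same before and after this change:

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    # The UDF still receives its input as a pandas.Series; only the
    # internal serializer/mapper wiring changed.
    return v.mean()

df.groupBy("id").agg(mean_udf(df["v"]).alias("mean_v")).show()
```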
### How was this patch tested?
Existing unit tests:
- `python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py`
- `python/pyspark/sql/tests/pandas/test_pandas_udf_window.py`
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53449 from Yicong-Huang/SPARK-54589/refactor/consolidate-serde-for-grouped-agg-pandas.
Authored-by: Yicong-Huang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 59 ++------------------------------
python/pyspark/worker.py | 42 ++++++++++++++++++-----
2 files changed, 37 insertions(+), 64 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 29e96d8a9123..fc86986e0fc0 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1135,7 +1135,8 @@ class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
return "ArrowStreamAggArrowUDFSerializer"
-# Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF
+# Serializer for SQL_GROUPED_AGG_PANDAS_UDF, SQL_WINDOW_AGG_PANDAS_UDF,
+# and SQL_GROUPED_AGG_PANDAS_ITER_UDF
class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
def __init__(
self,
@@ -1156,60 +1157,6 @@ class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
- def load_stream(self, stream):
- """
- Deserialize Grouped ArrowRecordBatches and yield as a list of pandas.Series.
- """
- import pyarrow as pa
-
- dataframes_in_group = None
-
- while dataframes_in_group is None or dataframes_in_group > 0:
- dataframes_in_group = read_int(stream)
-
- if dataframes_in_group == 1:
- yield (
- [
- self.arrow_to_pandas(c, i)
- for i, c in enumerate(
- pa.Table.from_batches(
- ArrowStreamSerializer.load_stream(self, stream)
- ).itercolumns()
- )
- ]
- )
-
- elif dataframes_in_group != 0:
- raise PySparkValueError(
- errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
- messageParameters={"dataframes_in_group":
str(dataframes_in_group)},
- )
-
- def __repr__(self):
- return "ArrowStreamAggPandasUDFSerializer"
-
-
-# Serializer for SQL_GROUPED_AGG_PANDAS_ITER_UDF
-class ArrowStreamAggPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
- def __init__(
- self,
- timezone,
- safecheck,
- assign_cols_by_name,
- int_to_decimal_coercion_enabled,
- ):
- super().__init__(
- timezone=timezone,
- safecheck=safecheck,
- assign_cols_by_name=assign_cols_by_name,
- df_for_struct=False,
- struct_in_pandas="dict",
- ndarray_as_list=False,
- arrow_cast=True,
- input_types=None,
- int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
- )
-
def load_stream(self, stream):
"""
Yield an iterator that produces one tuple of pandas.Series per batch.
@@ -1241,7 +1188,7 @@ class ArrowStreamAggPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
)
def __repr__(self):
- return "ArrowStreamAggPandasIterUDFSerializer"
+ return "ArrowStreamAggPandasUDFSerializer"
# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index be1dcb215e2d..fee28b149f53 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -65,7 +65,6 @@ from pyspark.sql.pandas.serializers import (
TransformWithStateInPySparkRowInitStateSerializer,
ArrowStreamArrowUDFSerializer,
ArrowStreamAggPandasUDFSerializer,
- ArrowStreamAggPandasIterUDFSerializer,
ArrowStreamAggArrowUDFSerializer,
ArrowBatchUDFSerializer,
ArrowStreamUDTFSerializer,
@@ -2780,15 +2779,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
ser = ArrowStreamAggArrowUDFSerializer(
runner_conf.timezone, True, runner_conf.assign_cols_by_name, True
)
- elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
- ser = ArrowStreamAggPandasIterUDFSerializer(
- runner_conf.timezone,
- runner_conf.safecheck,
- runner_conf.assign_cols_by_name,
- runner_conf.int_to_decimal_coercion_enabled,
- )
elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
):
ser = ArrowStreamAggPandasUDFSerializer(
@@ -3365,6 +3358,39 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
else:
return result
+ elif eval_type in (
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
+ ):
+ import pandas as pd
+
+ # For SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF,
+ # convert iterator of batch tuples to concatenated pandas Series
+ def mapper(batch_iter):
+ # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch
+ # Collect all batches and concatenate into single Series per column
+ batches = list(batch_iter)
+ if not batches:
+ # Empty batches - determine num_columns from all UDFs' arg_offsets
+ all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets]
+ num_columns = max(all_offsets) + 1 if all_offsets else 0
+ concatenated = [pd.Series(dtype=object) for _ in range(num_columns)]
+ else:
+ # Use actual number of columns from the first batch
+ num_columns = len(batches[0])
+ concatenated = [
+ pd.concat([batch[i] for batch in batches], ignore_index=True)
+ for i in range(num_columns)
+ ]
+
+ result = tuple(f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs)
+ # In the special case of a single UDF this will return a single result rather
+ # than a tuple of results; this is the format that the JVM side expects.
+ if len(result) == 1:
+ return result[0]
+ else:
+ return result
+
else:
def mapper(a):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]