This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 25c57d87382c [SPARK-54589][PYTHON] Consolidate ArrowStreamAggPandasIterUDFSerializer into ArrowStreamAggPandasUDFSerializer
25c57d87382c is described below

commit 25c57d87382ca90850541d4d20447bfd842ec85a
Author: Yicong-Huang <[email protected]>
AuthorDate: Tue Dec 16 12:32:11 2025 +0800

    [SPARK-54589][PYTHON] Consolidate ArrowStreamAggPandasIterUDFSerializer into ArrowStreamAggPandasUDFSerializer
    
    ### What changes were proposed in this pull request?
    
    This PR consolidates `ArrowStreamAggPandasIterUDFSerializer` into `ArrowStreamAggPandasUDFSerializer` for the `SQL_GROUPED_AGG_PANDAS_UDF` and `SQL_GROUPED_AGG_PANDAS_ITER_UDF` eval types.
    
    Changes:
    1. **Removed `ArrowStreamAggPandasIterUDFSerializer`** - The class was nearly identical to `ArrowStreamAggPandasUDFSerializer`
    2. **Unified serializer** - `ArrowStreamAggPandasUDFSerializer` now serves `SQL_GROUPED_AGG_PANDAS_UDF`, `SQL_GROUPED_AGG_PANDAS_ITER_UDF`, and `SQL_WINDOW_AGG_PANDAS_UDF` (see the usage sketch after this list)
    3. **Added mapper for non-iter UDFs** - A new mapper in `worker.py` handles batch concatenation for `SQL_GROUPED_AGG_PANDAS_UDF` and `SQL_WINDOW_AGG_PANDAS_UDF`
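
    For orientation, a minimal usage sketch (not part of this patch) of a grouped-aggregate pandas UDF, one of the UDF types now served by the consolidated `ArrowStreamAggPandasUDFSerializer`; the session, data, and column names are illustrative:

    ```python
    import pandas as pd
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ["id", "v"])

    # Grouped-aggregate pandas UDF (SQL_GROUPED_AGG_PANDAS_UDF): receives a
    # group's column as a pandas.Series and returns one scalar per group.
    @pandas_udf("double")
    def mean_udf(v: pd.Series) -> float:
        return float(v.mean())

    df.groupBy("id").agg(mean_udf(df["v"]).alias("mean_v")).show()
    ```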
    
    ### Why are the changes needed?
    
    Similar to SPARK-54316, the two serializer classes had nearly identical implementations:
    - Identical `__init__` methods
    - Same base class (`ArrowStreamPandasUDFSerializer`)
    - Only `load_stream` differed slightly in output format (illustrated just below)
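
    A rough illustration of that output-format difference, with hypothetical data (two Arrow batches in one group), not code taken from the patch:

    ```python
    import pandas as pd

    # Removed non-iter load_stream: per group, one list of pandas.Series built
    # from all of the group's batches concatenated.
    non_iter_group = [pd.Series([1.0, 2.0, 3.0, 4.0])]

    # Retained iter-style load_stream: per group, an iterator producing one
    # tuple of pandas.Series per Arrow batch.
    iter_group = iter([(pd.Series([1.0, 2.0]),), (pd.Series([3.0, 4.0]),)])
    ```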
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. It's an internal refactor.
    
    ### How was this patch tested?
    
    Existing unit tests:
    - `python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py`
    - `python/pyspark/sql/tests/pandas/test_pandas_udf_window.py`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53449 from Yicong-Huang/SPARK-54589/refactor/consolidate-serde-for-grouped-agg-pandas.
    
    Authored-by: Yicong-Huang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/pandas/serializers.py | 59 ++------------------------------
 python/pyspark/worker.py                 | 42 ++++++++++++++++++-----
 2 files changed, 37 insertions(+), 64 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 29e96d8a9123..fc86986e0fc0 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1135,7 +1135,8 @@ class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
         return "ArrowStreamAggArrowUDFSerializer"
 
 
-# Serializer for SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF
+# Serializer for SQL_GROUPED_AGG_PANDAS_UDF, SQL_WINDOW_AGG_PANDAS_UDF,
+# and SQL_GROUPED_AGG_PANDAS_ITER_UDF
 class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
     def __init__(
         self,
@@ -1156,60 +1157,6 @@ class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
 
-    def load_stream(self, stream):
-        """
-        Deserialize Grouped ArrowRecordBatches and yield as a list of pandas.Series.
-        """
-        import pyarrow as pa
-
-        dataframes_in_group = None
-
-        while dataframes_in_group is None or dataframes_in_group > 0:
-            dataframes_in_group = read_int(stream)
-
-            if dataframes_in_group == 1:
-                yield (
-                    [
-                        self.arrow_to_pandas(c, i)
-                        for i, c in enumerate(
-                            pa.Table.from_batches(
-                                ArrowStreamSerializer.load_stream(self, stream)
-                            ).itercolumns()
-                        )
-                    ]
-                )
-
-            elif dataframes_in_group != 0:
-                raise PySparkValueError(
-                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
-                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
-                )
-
-    def __repr__(self):
-        return "ArrowStreamAggPandasUDFSerializer"
-
-
-# Serializer for SQL_GROUPED_AGG_PANDAS_ITER_UDF
-class ArrowStreamAggPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
-    def __init__(
-        self,
-        timezone,
-        safecheck,
-        assign_cols_by_name,
-        int_to_decimal_coercion_enabled,
-    ):
-        super().__init__(
-            timezone=timezone,
-            safecheck=safecheck,
-            assign_cols_by_name=assign_cols_by_name,
-            df_for_struct=False,
-            struct_in_pandas="dict",
-            ndarray_as_list=False,
-            arrow_cast=True,
-            input_types=None,
-            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
-        )
-
     def load_stream(self, stream):
         """
         Yield an iterator that produces one tuple of pandas.Series per batch.
@@ -1241,7 +1188,7 @@ class ArrowStreamAggPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
                 )
 
     def __repr__(self):
-        return "ArrowStreamAggPandasIterUDFSerializer"
+        return "ArrowStreamAggPandasUDFSerializer"
 
 
 # Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index be1dcb215e2d..fee28b149f53 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -65,7 +65,6 @@ from pyspark.sql.pandas.serializers import (
     TransformWithStateInPySparkRowInitStateSerializer,
     ArrowStreamArrowUDFSerializer,
     ArrowStreamAggPandasUDFSerializer,
-    ArrowStreamAggPandasIterUDFSerializer,
     ArrowStreamAggArrowUDFSerializer,
     ArrowBatchUDFSerializer,
     ArrowStreamUDTFSerializer,
@@ -2780,15 +2779,9 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
             ser = ArrowStreamAggArrowUDFSerializer(
                 runner_conf.timezone, True, runner_conf.assign_cols_by_name, True
             )
-        elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
-            ser = ArrowStreamAggPandasIterUDFSerializer(
-                runner_conf.timezone,
-                runner_conf.safecheck,
-                runner_conf.assign_cols_by_name,
-                runner_conf.int_to_decimal_coercion_enabled,
-            )
         elif eval_type in (
             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+            PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
             PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
         ):
             ser = ArrowStreamAggPandasUDFSerializer(
@@ -3365,6 +3358,39 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
             else:
                 return result
 
+    elif eval_type in (
+        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
+    ):
+        import pandas as pd
+
+        # For SQL_GROUPED_AGG_PANDAS_UDF and SQL_WINDOW_AGG_PANDAS_UDF,
+        # convert iterator of batch tuples to concatenated pandas Series
+        def mapper(batch_iter):
+            # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch
+            # Collect all batches and concatenate into single Series per column
+            batches = list(batch_iter)
+            if not batches:
+                # Empty batches - determine num_columns from all UDFs' arg_offsets
+                all_offsets = [o for arg_offsets, _ in udfs for o in arg_offsets]
+                num_columns = max(all_offsets) + 1 if all_offsets else 0
+                concatenated = [pd.Series(dtype=object) for _ in range(num_columns)]
+            else:
+                # Use actual number of columns from the first batch
+                num_columns = len(batches[0])
+                concatenated = [
+                    pd.concat([batch[i] for batch in batches], ignore_index=True)
+                    for i in range(num_columns)
+                ]
+
+            result = tuple(f(*[concatenated[o] for o in arg_offsets]) for arg_offsets, f in udfs)
+            # In the special case of a single UDF this will return a single result rather
+            # than a tuple of results; this is the format that the JVM side expects.
+            if len(result) == 1:
+                return result[0]
+            else:
+                return result
+
     else:
 
         def mapper(a):
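
For readers skimming the diff above, a standalone sketch of what the new non-iter mapper does, using hypothetical two-column input batches (the names and data are illustrative, not taken from the patch):

```python
import pandas as pd

def concat_batches(batch_iter):
    # Each element of batch_iter is one batch: a tuple of pandas.Series, one
    # Series per input column. Concatenate column-wise so an aggregating UDF
    # sees each column of the whole group as a single Series.
    batches = list(batch_iter)
    num_columns = len(batches[0])
    return [
        pd.concat([batch[i] for batch in batches], ignore_index=True)
        for i in range(num_columns)
    ]

# Hypothetical input: two batches, two columns each.
batches = iter([
    (pd.Series([1, 2]), pd.Series([10.0, 20.0])),
    (pd.Series([3]), pd.Series([30.0])),
])
cols = concat_batches(batches)
print(cols[0].tolist(), cols[1].tolist())  # [1, 2, 3] [10.0, 20.0, 30.0]
```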


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
