This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f601aa65c0dc [SPARK-53616][PYTHON] Introduce iterator API for pandas grouped agg UDF
f601aa65c0dc is described below
commit f601aa65c0dc48a28afd1277a1aafe1b784f4eb6
Author: Yicong-Huang <[email protected]>
AuthorDate: Fri Dec 12 08:49:14 2025 +0800
[SPARK-53616][PYTHON] Introduce iterator API for pandas grouped agg UDF
### What changes were proposed in this pull request?
This PR introduces an iterator API for pandas grouped aggregation UDFs,
enabling batch-by-batch processing to improve memory efficiency. Users can now
write UDFs that accept `Iterator[pd.Series]` or `Iterator[Tuple[pd.Series,
...]]` and return a scalar value, allowing them to process large groups
incrementally without loading all data into memory at once.
This brings pandas UDFs to feature parity with Arrow UDFs, which already
support the iterator API.
### Why are the changes needed?
The iterator API provides better memory efficiency for grouped aggregation
UDFs by allowing batch-by-batch processing instead of loading all data for a
group into memory at once. This is especially beneficial for:
- Large groups that don't fit in memory
- Streaming or incremental processing scenarios
- Memory-constrained environments
### Does this PR introduce _any_ user-facing change?
Yes. This introduces a new way to write pandas grouped aggregation UDFs
using iterator-based type hints:
**Before:**
```python
pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
return v.mean()
```
**After (new iterator API):**
```python
pandas_udf("double")
def mean_iter(it: Iterator[pd.Series]) -> float:
sum_val = 0.0
cnt = 0
for v in it:
sum_val += v.sum()
cnt += len(v)
return sum_val / cnt
```
The iterator API is automatically detected via type hints
(`Iterator[pd.Series]` or `Iterator[Tuple[pd.Series, ...]]`), so existing code
continues to work unchanged.
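For reference, the same mechanism also covers multiple input columns. Below is a minimal end-to-end sketch adapted from the weighted-mean doctest added in this PR; the session setup and names are illustrative, and it requires a build that includes this change:
```python
from typing import Iterator, Tuple

import numpy as np
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.getOrCreate()

@pandas_udf("double")
def weighted_mean_iter(it: Iterator[Tuple[pd.Series, pd.Series]]) -> float:
    # Accumulate per batch; only one batch of the group is materialized at a time.
    weighted_sum, weight = 0.0, 0.0
    for v, w in it:
        weighted_sum += np.dot(v, w)
        weight += w.sum()
    return weighted_sum / weight

df = spark.createDataFrame(
    [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
    ("id", "v", "w"),
)
df.groupby("id").agg(weighted_mean_iter(df["v"], df["w"])).show()
```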
### How was this patch tested?
- **Type Hint Tests**: Added tests in `test_pandas_udf_typehints.py`
verifying correct inference of iterator-based UDFs
- **Functional Tests**: Added tests in `test_pandas_udf_grouped_agg.py`
covering:
- Single column input (`Iterator[pd.Series]`)
- Multiple column input (`Iterator[Tuple[pd.Series, pd.Series]]`)
- Eval type verification (see the sketch below)
- Partial consumption scenarios
- **Documentation Tests**: All doctests pass, including new iterator API
examples
- **Integration**: Verified compatibility with existing grouped aggregation
UDFs
- All existing tests continue to pass
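As a concrete illustration of the eval-type verification mentioned above, the sketch below (adapted from the new tests; the UDF name is illustrative) shows that the iterator type hint alone selects the new eval type on a build containing this change:
```python
from typing import Iterator

import pandas as pd

from pyspark.sql.functions import pandas_udf
from pyspark.util import PythonEvalType

@pandas_udf("double")
def total_iter(it: Iterator[pd.Series]) -> float:
    return float(sum(s.sum() for s in it))

# Inferred from the Iterator[pd.Series] -> float type hints (eval type 217).
assert total_iter.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF
```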
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53317 from Yicong-Huang/SPARK-53616/feat/introduce-iterator-api-for-pandas-grouped-agg-udf.
Lead-authored-by: Yicong-Huang <[email protected]>
Co-authored-by: Yicong Huang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../org/apache/spark/api/python/PythonRunner.scala | 2 +
python/pyspark/sql/pandas/_typing/__init__.pyi | 1 +
python/pyspark/sql/pandas/functions.py | 67 +++++
python/pyspark/sql/pandas/functions.pyi | 2 +
python/pyspark/sql/pandas/serializers.py | 55 ++++
python/pyspark/sql/pandas/typehints.py | 50 +++-
.../tests/pandas/test_pandas_udf_grouped_agg.py | 292 ++++++++++++++++++---
.../sql/tests/pandas/test_pandas_udf_typehints.py | 90 +++++++
python/pyspark/util.py | 2 +
python/pyspark/worker.py | 60 +++++
.../python/ArrowAggregatePythonExec.scala | 4 +-
.../python/UserDefinedPythonFunction.scala | 2 +
12 files changed, 591 insertions(+), 36 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 63484c23a920..ccc61986d176 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -69,6 +69,7 @@ private[spark] object PythonEvalType {
val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214
val SQL_GROUPED_MAP_ARROW_ITER_UDF = 215
val SQL_GROUPED_MAP_PANDAS_ITER_UDF = 216
+ val SQL_GROUPED_AGG_PANDAS_ITER_UDF = 217
// Arrow UDFs
val SQL_SCALAR_ARROW_UDF = 250
@@ -107,6 +108,7 @@ private[spark] object PythonEvalType {
"SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF"
case SQL_GROUPED_MAP_ARROW_ITER_UDF => "SQL_GROUPED_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_ITER_UDF => "SQL_GROUPED_MAP_PANDAS_ITER_UDF"
+ case SQL_GROUPED_AGG_PANDAS_ITER_UDF => "SQL_GROUPED_AGG_PANDAS_ITER_UDF"
// Arrow UDFs
case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi
index b51c507ed4dd..cc3f599aca67 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -60,6 +60,7 @@ GroupedMapUDFTransformWithStateType = Literal[213]
GroupedMapUDFTransformWithStateInitStateType = Literal[214]
ArrowGroupedMapIterUDFType = Literal[215]
PandasGroupedMapIterUDFType = Literal[216]
+PandasGroupedAggIterUDFType = Literal[217]
# Arrow UDFs
ArrowScalarUDFType = Literal[250]
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index b6aaa865e929..7d8a90d57905 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -40,6 +40,8 @@ class PandasUDFType:
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
+ GROUPED_AGG_ITER = PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF
+
class ArrowUDFType:
"""Arrow UDF Types. See :meth:`pyspark.sql.functions.arrow_udf`."""
@@ -415,6 +417,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. versionchanged:: 4.1.0
Supports iterator API in GROUPED_MAP type.
+ .. versionchanged:: 4.2.0
+ Supports iterator API in GROUPED_AGG type.
+
Parameters
----------
f : function, optional
@@ -673,6 +678,66 @@ def pandas_udf(f=None, returnType=None, functionType=None):
Therefore, mutating the input series is not allowed and will cause incorrect results.
For the same reason, users should also not rely on the index of the input series.
+ * Iterator of Series to Scalar
+ `Iterator[pandas.Series]` -> `Any`
+
+ The function takes an iterator of `pandas.Series` and returns a scalar value. This is
+ useful for grouped aggregations where the UDF can process all batches for a group
+ iteratively, which is more memory-efficient than loading all data at once. The returned
+ scalar can be a python primitive type or a numpy data type.
+
+ .. note:: Only a single UDF is supported per aggregation.
+
+ >>> from typing import Iterator
+ >>> @pandas_udf("double")
+ ... def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ ... sum_val = 0.0
+ ... cnt = 0
+ ... for v in it:
+ ... sum_val += v.sum()
+ ... cnt += len(v)
+ ... return sum_val / cnt
+ ...
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+ >>> df.groupby("id").agg(pandas_mean_iter(df['v'])).show()
+ +---+-------------------+
+ | id|pandas_mean_iter(v)|
+ +---+-------------------+
+ | 1| 1.5|
+ | 2| 6.0|
+ +---+-------------------+
+
+ * Iterator of Multiple Series to Scalar
+ `Iterator[Tuple[pandas.Series, ...]]` -> `Any`
+
+ The function takes an iterator of a tuple of multiple `pandas.Series` and returns a
+ scalar value. This is useful for grouped aggregations with multiple input columns.
+
+ .. note:: Only a single UDF is supported per aggregation.
+
+ >>> from typing import Iterator, Tuple
+ >>> import numpy as np
+ >>> @pandas_udf("double")
+ ... def pandas_weighted_mean_iter(it: Iterator[Tuple[pd.Series, pd.Series]]) -> float:
+ ... weighted_sum = 0.0
+ ... weight = 0.0
+ ... for v, w in it:
+ ... weighted_sum += np.dot(v, w)
+ ... weight += w.sum()
+ ... return weighted_sum / weight
+ ...
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+ ... ("id", "v", "w"))
+ >>> df.groupby("id").agg(pandas_weighted_mean_iter(df["v"], df["w"])).show()
+ +---+-------------------------------+
+ | id|pandas_weighted_mean_iter(v, w)|
+ +---+-------------------------------+
+ | 1| 1.6666666666666...|
+ | 2| 7.166666666666...|
+ +---+-------------------------------+
+
Notes
-----
The user-defined functions do not support conditional expressions or short circuiting
@@ -761,6 +826,7 @@ def vectorized_udf(
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
@@ -823,6 +889,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
]:
warnings.warn(
"In Python 3.6+ and Spark 3.0+, it is preferred to specify type
hints for "
diff --git a/python/pyspark/sql/pandas/functions.pyi b/python/pyspark/sql/pandas/functions.pyi
index d835208d02bf..ff587b4ecbe6 100644
--- a/python/pyspark/sql/pandas/functions.pyi
+++ b/python/pyspark/sql/pandas/functions.pyi
@@ -28,6 +28,7 @@ from pyspark.sql.pandas._typing import (
GroupedMapPandasUserDefinedFunction,
PandasGroupedAggFunction,
PandasGroupedAggUDFType,
+ PandasGroupedAggIterUDFType,
PandasGroupedMapFunction,
PandasGroupedMapIterUDFType,
PandasGroupedMapUDFType,
@@ -53,6 +54,7 @@ class PandasUDFType:
SCALAR_ITER: PandasScalarIterUDFType
GROUPED_MAP: PandasGroupedMapUDFType
GROUPED_AGG: PandasGroupedAggUDFType
+ GROUPED_AGG_ITER: PandasGroupedAggIterUDFType
class ArrowUDFType:
SCALAR: ArrowScalarUDFType
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 6a96aa3d2bf2..29e96d8a9123 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1189,6 +1189,61 @@ class ArrowStreamAggPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
return "ArrowStreamAggPandasUDFSerializer"
+# Serializer for SQL_GROUPED_AGG_PANDAS_ITER_UDF
+class ArrowStreamAggPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ int_to_decimal_coercion_enabled,
+ ):
+ super().__init__(
+ timezone=timezone,
+ safecheck=safecheck,
+ assign_cols_by_name=assign_cols_by_name,
+ df_for_struct=False,
+ struct_in_pandas="dict",
+ ndarray_as_list=False,
+ arrow_cast=True,
+ input_types=None,
+ int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ )
+
+ def load_stream(self, stream):
+ """
+ Yield an iterator that produces one tuple of pandas.Series per batch.
+ Each group yields Iterator[Tuple[pd.Series, ...]], allowing UDF to
+ process batches one by one without consuming all batches upfront.
+ """
+
+ dataframes_in_group = None
+
+ while dataframes_in_group is None or dataframes_in_group > 0:
+ dataframes_in_group = read_int(stream)
+
+ if dataframes_in_group == 1:
+ # Lazily read and convert Arrow batches to pandas Series one at a time
+ # from the stream. This avoids loading all batches into memory for the group
+ batch_iter = (
+ tuple(self.arrow_to_pandas(c, i) for i, c in enumerate(batch.columns))
+ for batch in ArrowStreamSerializer.load_stream(self, stream)
+ )
+ yield batch_iter
+ # Make sure the batches are fully iterated before getting the next group
+ for _ in batch_iter:
+ pass
+
+ elif dataframes_in_group != 0:
+ raise PySparkValueError(
+ errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+ messageParameters={"dataframes_in_group": str(dataframes_in_group)},
+ )
+
+ def __repr__(self):
+ return "ArrowStreamAggPandasIterUDFSerializer"
+
+
# Serializer for SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_PANDAS_ITER_UDF
class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
def __init__(
diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py
index 8bb8d137be9a..4490748568e7 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
PandasScalarUDFType,
PandasScalarIterUDFType,
PandasGroupedAggUDFType,
+ PandasGroupedAggIterUDFType,
ArrowScalarUDFType,
ArrowScalarIterUDFType,
ArrowGroupedAggUDFType,
@@ -42,7 +43,14 @@ if TYPE_CHECKING:
def infer_pandas_eval_type(
sig: Signature,
type_hints: Dict[str, Any],
-) -> Optional[Union["PandasScalarUDFType", "PandasScalarIterUDFType", "PandasGroupedAggUDFType"]]:
+) -> Optional[
+ Union[
+ "PandasScalarUDFType",
+ "PandasScalarIterUDFType",
+ "PandasGroupedAggUDFType",
+ "PandasGroupedAggIterUDFType",
+ ]
+]:
"""
Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
:class:`inspect.Signature` instance and type hints.
@@ -152,6 +160,43 @@ def infer_pandas_eval_type(
if is_series_or_frame_agg:
return PandasUDFType.GROUPED_AGG
+ # Iterator[Tuple[Series, ...]] -> Any
+ is_iterator_tuple_series_agg = (
+ len(parameters_sig) == 1
+ and check_iterator_annotation( # Iterator
+ parameters_sig[0],
+ parameter_check_func=lambda a: check_tuple_annotation( # Tuple
+ a,
+ parameter_check_func=lambda ta: (ta == Ellipsis or ta == pd.Series),
+ ),
+ )
+ and (
+ return_annotation != pd.Series
+ and return_annotation != pd.DataFrame
+ and not check_iterator_annotation(return_annotation)
+ and not check_tuple_annotation(return_annotation)
+ )
+ )
+ if is_iterator_tuple_series_agg:
+ return PandasUDFType.GROUPED_AGG_ITER
+
+ # Iterator[Series] -> Any
+ is_iterator_series_agg = (
+ len(parameters_sig) == 1
+ and check_iterator_annotation(
+ parameters_sig[0],
+ parameter_check_func=lambda a: a == pd.Series,
+ )
+ and (
+ return_annotation != pd.Series
+ and return_annotation != pd.DataFrame
+ and not check_iterator_annotation(return_annotation)
+ and not check_tuple_annotation(return_annotation)
+ )
+ )
+ if is_iterator_series_agg:
+ return PandasUDFType.GROUPED_AGG_ITER
+
return None
@@ -289,6 +334,7 @@ def infer_eval_type(
"PandasScalarUDFType",
"PandasScalarIterUDFType",
"PandasGroupedAggUDFType",
+ "PandasGroupedAggIterUDFType",
"ArrowScalarUDFType",
"ArrowScalarIterUDFType",
"ArrowGroupedAggUDFType",
@@ -305,6 +351,7 @@ def infer_eval_type(
"PandasScalarUDFType",
"PandasScalarIterUDFType",
"PandasGroupedAggUDFType",
+ "PandasGroupedAggIterUDFType",
"ArrowScalarUDFType",
"ArrowScalarIterUDFType",
"ArrowGroupedAggUDFType",
@@ -337,6 +384,7 @@ def infer_eval_type_for_udf( # type: ignore[no-untyped-def]
"PandasScalarUDFType",
"PandasScalarIterUDFType",
"PandasGroupedAggUDFType",
+ "PandasGroupedAggIterUDFType",
"ArrowScalarUDFType",
"ArrowScalarIterUDFType",
"ArrowGroupedAggUDFType",
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index 4b66dee5b7af..c71e78609f82 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -17,7 +17,7 @@
import unittest
import logging
-from typing import cast
+from typing import cast, Iterator, Tuple
from pyspark.util import PythonEvalType, is_remote_only
from pyspark.sql import Row, functions as sf
@@ -27,7 +27,6 @@ from pyspark.sql.functions import (
col,
lit,
mean,
- sum,
udf,
pandas_udf,
PandasUDFType,
@@ -246,16 +245,16 @@ class GroupedAggPandasUDFTestsMixin:
# Mix group aggregate pandas UDF with sql expression
result1 = df.groupby("id").agg(sum_udf(df.v) + 1).sort("id")
- expected1 = df.groupby("id").agg(sum(df.v) + 1).sort("id")
+ expected1 = df.groupby("id").agg(sf.sum(df.v) + 1).sort("id")
# Mix group aggregate pandas UDF with sql expression (order swapped)
result2 = df.groupby("id").agg(sum_udf(df.v + 1)).sort("id")
- expected2 = df.groupby("id").agg(sum(df.v + 1)).sort("id")
+ expected2 = df.groupby("id").agg(sf.sum(df.v + 1)).sort("id")
# Wrap group aggregate pandas UDF with two sql expressions
result3 = df.groupby("id").agg(sum_udf(df.v + 1) + 2).sort("id")
- expected3 = df.groupby("id").agg(sum(df.v + 1) + 2).sort("id")
+ expected3 = df.groupby("id").agg(sf.sum(df.v + 1) + 2).sort("id")
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -272,26 +271,26 @@ class GroupedAggPandasUDFTestsMixin:
# Mix group aggregate pandas UDF and python UDF
result1 = df.groupby("id").agg(plus_one(sum_udf(df.v))).sort("id")
- expected1 = df.groupby("id").agg(plus_one(sum(df.v))).sort("id")
+ expected1 = df.groupby("id").agg(plus_one(sf.sum(df.v))).sort("id")
# Mix group aggregate pandas UDF and python UDF (order swapped)
result2 = df.groupby("id").agg(sum_udf(plus_one(df.v))).sort("id")
- expected2 = df.groupby("id").agg(sum(plus_one(df.v))).sort("id")
+ expected2 = df.groupby("id").agg(sf.sum(plus_one(df.v))).sort("id")
# Mix group aggregate pandas UDF and scalar pandas UDF
result3 = df.groupby("id").agg(sum_udf(plus_two(df.v))).sort("id")
- expected3 = df.groupby("id").agg(sum(plus_two(df.v))).sort("id")
+ expected3 = df.groupby("id").agg(sf.sum(plus_two(df.v))).sort("id")
# Mix group aggregate pandas UDF and scalar pandas UDF (order swapped)
result4 = df.groupby("id").agg(plus_two(sum_udf(df.v))).sort("id")
- expected4 = df.groupby("id").agg(plus_two(sum(df.v))).sort("id")
+ expected4 = df.groupby("id").agg(plus_two(sf.sum(df.v))).sort("id")
# Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby
result5 = (
df.groupby(plus_one(df.id)).agg(plus_one(sum_udf(plus_one(df.v)))).sort("plus_one(id)")
)
expected5 = (
- df.groupby(plus_one(df.id)).agg(plus_one(sum(plus_one(df.v)))).sort("plus_one(id)")
+ df.groupby(plus_one(df.id)).agg(plus_one(sf.sum(plus_one(df.v)))).sort("plus_one(id)")
)
# Wrap group aggregate pandas UDF with two scala pandas UDF and user scala pandas UDF in
@@ -300,7 +299,7 @@ class GroupedAggPandasUDFTestsMixin:
df.groupby(plus_two(df.id)).agg(plus_two(sum_udf(plus_two(df.v)))).sort("plus_two(id)")
)
expected6 = (
- df.groupby(plus_two(df.id)).agg(plus_two(sum(plus_two(df.v)))).sort("plus_two(id)")
+ df.groupby(plus_two(df.id)).agg(plus_two(sf.sum(plus_two(df.v)))).sort("plus_two(id)")
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -327,7 +326,7 @@ class GroupedAggPandasUDFTestsMixin:
)
expected1 = (
df.groupBy("id")
- .agg(mean(df.v), sum(df.v), mean(df.v).alias("weighted_mean(v, w)"))
+ .agg(mean(df.v), sf.sum(df.v), mean(df.v).alias("weighted_mean(v, w)"))
.sort("id")
.toPandas()
)
@@ -342,23 +341,23 @@ class GroupedAggPandasUDFTestsMixin:
# groupby one expression
result1 = df.groupby(df.v % 2).agg(sum_udf(df.v))
- expected1 = df.groupby(df.v % 2).agg(sum(df.v))
+ expected1 = df.groupby(df.v % 2).agg(sf.sum(df.v))
# empty groupby
result2 = df.groupby().agg(sum_udf(df.v))
- expected2 = df.groupby().agg(sum(df.v))
+ expected2 = df.groupby().agg(sf.sum(df.v))
# groupby one column and one sql expression
result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2)
- expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2)
+ expected3 = df.groupby(df.id, df.v % 2).agg(sf.sum(df.v)).orderBy(df.id, df.v % 2)
# groupby one python UDF
result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)).sort("plus_one(id)")
- expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v)).sort("plus_one(id)")
+ expected4 = df.groupby(plus_one(df.id)).agg(sf.sum(df.v)).sort("plus_one(id)")
# groupby one scalar pandas UDF
result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)).sort("sum(v)")
- expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)).sort("sum(v)")
+ expected5 = df.groupby(plus_two(df.id)).agg(sf.sum(df.v)).sort("sum(v)")
# groupby one expression and one python UDF
result6 = (
@@ -367,7 +366,9 @@ class GroupedAggPandasUDFTestsMixin:
.sort(["(v % 2)", "plus_one(id)"])
)
expected6 = (
- df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v)).sort(["(v % 2)", "plus_one(id)"])
+ df.groupby(df.v % 2, plus_one(df.id))
+ .agg(sf.sum(df.v))
+ .sort(["(v % 2)", "plus_one(id)"])
)
# groupby one expression and one scalar pandas UDF
@@ -377,7 +378,7 @@ class GroupedAggPandasUDFTestsMixin:
.sort(["sum(v)", "plus_two(id)"])
)
expected7 = (
- df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort(["sum(v)", "plus_two(id)"])
+ df.groupby(df.v % 2, plus_two(df.id)).agg(sf.sum(df.v)).sort(["sum(v)", "plus_two(id)"])
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -417,11 +418,11 @@ class GroupedAggPandasUDFTestsMixin:
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
- sum(col("v")),
- sum(col("v1") + 3),
- sum(col("v2")) + 5,
- plus_one(sum(col("v1"))),
- sum(plus_one(col("v2"))),
+ sf.sum(col("v")),
+ sf.sum(col("v1") + 3),
+ sf.sum(col("v2")) + 5,
+ plus_one(sf.sum(col("v1"))),
+ sf.sum(plus_one(col("v2"))),
)
.sort(["id", "(v % 2)"])
.toPandas()
@@ -451,11 +452,11 @@ class GroupedAggPandasUDFTestsMixin:
.withColumn("v2", df.v + 2)
.groupby(df.id, df.v % 2)
.agg(
- sum(col("v")),
- sum(col("v1") + 3),
- sum(col("v2")) + 5,
- plus_two(sum(col("v1"))),
- sum(plus_two(col("v2"))),
+ sf.sum(col("v")),
+ sf.sum(col("v1") + 3),
+ sf.sum(col("v2")) + 5,
+ plus_two(sf.sum(col("v1"))),
+ sf.sum(plus_two(col("v2"))),
)
.sort(["id", "(v % 2)"])
.toPandas()
@@ -474,9 +475,9 @@ class GroupedAggPandasUDFTestsMixin:
expected3 = (
df.groupby("id")
- .agg(sum(df.v).alias("v"))
+ .agg(sf.sum(df.v).alias("v"))
.groupby("id")
- .agg(sum(col("v")))
+ .agg(sf.sum(col("v")))
.sort("id")
.toPandas()
)
@@ -491,7 +492,7 @@ class GroupedAggPandasUDFTestsMixin:
sum_udf = self.pandas_agg_sum_udf
result1 = df.groupby(df.id).agg(sum_udf(df.v))
- expected1 = df.groupby(df.id).agg(sum(df.v))
+ expected1 = df.groupby(df.id).agg(sf.sum(df.v))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_array_type(self):
@@ -706,7 +707,7 @@ class GroupedAggPandasUDFTestsMixin:
):
with self.subTest(with_w=False, query_no=i):
assertDataFrameEqual(
- aggregated, df.groupby("id").agg((sum(df.v) + lit(100)).alias("s"))
+ aggregated, df.groupby("id").agg((sf.sum(df.v) + lit(100)).alias("s"))
)
# with "w"
@@ -722,7 +723,7 @@ class GroupedAggPandasUDFTestsMixin:
):
with self.subTest(with_w=True, query_no=i):
assertDataFrameEqual(
- aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
+ aggregated, df.groupby("id").agg((sf.sum(df.v) + sf.sum(df.w)).alias("s"))
)
def test_arrow_cast_enabled_numeric_to_decimal(self):
@@ -897,6 +898,229 @@ class GroupedAggPandasUDFTestsMixin:
)
assert_frame_equal(expected, result)
+ def test_iterator_grouped_agg_basic(self):
+ """
+ Test basic functionality of iterator grouped agg pandas UDF with Iterator[pd.Series].
+ """
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ @pandas_udf("double")
+ def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for series in it:
+ assert isinstance(series, pd.Series)
+ sum_val += series.sum()
+ cnt += len(series)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ result = df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id").collect()
+
+ # Expected means:
+ # Group 1: (1.0 + 2.0) / 2 = 1.5
+ # Group 2: (3.0 + 5.0 + 10.0) / 3 = 6.0
+ expected = [Row(id=1, mean=1.5), Row(id=2, mean=6.0)]
+ self.assertEqual(result, expected)
+
+ def test_iterator_grouped_agg_multiple_columns(self):
+ """
+ Test iterator grouped agg pandas UDF with multiple columns
+ using Iterator[Tuple[pd.Series, ...]].
+ """
+ df = self.spark.createDataFrame(
+ [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+ ("id", "v", "w"),
+ )
+
+ @pandas_udf("double")
+ def pandas_weighted_mean_iter(it: Iterator[Tuple[pd.Series, pd.Series]]) -> float:
+ import numpy as np
+
+ weighted_sum = 0.0
+ weight = 0.0
+ for v_series, w_series in it:
+ assert isinstance(v_series, pd.Series)
+ assert isinstance(w_series, pd.Series)
+ weighted_sum += np.dot(v_series, w_series)
+ weight += w_series.sum()
+ return weighted_sum / weight if weight > 0 else 0.0
+
+ result = (
+ df.groupby("id")
+ .agg(pandas_weighted_mean_iter(df["v"], df["w"]).alias("wm"))
+ .sort("id")
+ .collect()
+ )
+
+ # Expected weighted means:
+ # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+ # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 / 6.0
+ expected = [Row(id=1, wm=5.0 / 3.0), Row(id=2, wm=43.0 / 6.0)]
+ self.assertEqual(result, expected)
+
+ def test_iterator_grouped_agg_eval_type(self):
+ """
+ Test that the eval type is correctly inferred for iterator grouped agg UDFs.
+ """
+
+ @pandas_udf("double")
+ def pandas_sum_iter(it: Iterator[pd.Series]) -> float:
+ total = 0.0
+ for series in it:
+ total += series.sum()
+ return total
+
+ self.assertEqual(pandas_sum_iter.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF)
+
+ @pandas_udf("double")
+ def pandas_sum_iter_tuple(it: Iterator[Tuple[pd.Series, pd.Series]]) -> float:
+ total = 0.0
+ for v, w in it:
+ total += v.sum()
+ return total
+
+ self.assertEqual(
+ pandas_sum_iter_tuple.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF
+ )
+
+ def test_iterator_grouped_agg_partial_consumption(self):
+ """
+ Test that iterator grouped agg UDF can partially consume batches.
+ This ensures that batches are processed one by one without loading all data into memory.
+ """
+ # Create a dataset with multiple batches per group
+ # Use small batch size to ensure multiple batches per group
+ # Use same value (1.0) for all records to avoid batch ordering issues
+ with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
+ # Group 1: 6 values (3 batches) - will process only first 2 batches (partial)
+ # Group 2: 2 values (1 batch) - will process 1 batch (all available)
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (1, 1.0), (2, 1.0), (2, 1.0)],
+ ("id", "v"),
+ )
+
+ @pandas_udf("long")
+ def pandas_partial_count(it: Iterator[pd.Series]) -> int:
+ # Process first 2 batches, then stop (partial consumption)
+ total_count = 0
+ for i, series in enumerate(it):
+ assert isinstance(series, pd.Series)
+ if i < 2: # Process first 2 batches
+ total_count += len(series)
+ else:
+ # Stop early - partial consumption
+ break
+ return total_count
+
+ result = df.groupby("id").agg(pandas_partial_count(df["v"]).alias("count")).sort("id")
+
+ # Verify results are correct for partial consumption
+ # With batch size = 2:
+ # Group 1 (id=1): 6 values in 3 batches -> processes only first 2 batches (partial)
+ # Result: count=4 (only 4 out of 6 values processed)
+ # Group 2 (id=2): 2 values in 1 batch -> processes 1 batch (all available)
+ # Result: count=2
+ actual = result.collect()
+ self.assertEqual(len(actual), 2, "Should have results for both groups")
+
+ # Verify partial consumption works
+ # Group 1: processes only 2 batches (4 values out of 6 total) - partial consumption
+ group1_result = next(row for row in actual if row["id"] == 1)
+ self.assertEqual(
+ group1_result["count"], 4, msg="Group 1 should process only 2 batches (4 values)"
+ )
+
+ # Group 2: processes 1 batch (all 2 values, 1 batch available)
+ group2_result = next(row for row in actual if row["id"] == 2)
+ self.assertEqual(
+ group2_result["count"], 2, msg="Group 2 should process 1 batch (2 values)"
+ )
+
+ def test_grouped_agg_with_struct_type_input(self):
+ """
+ Test that grouped agg UDF works with struct type input.
+ Struct types should be passed as pd.DataFrame to the UDF (similar to scalar pandas UDFs).
+ """
+ from pyspark.sql import Row
+
+ # Create a DataFrame with struct column
+ df = self.spark.createDataFrame(
+ [
+ (1, Row(name="Alice", age=25)),
+ (1, Row(name="Bob", age=30)),
+ (2, Row(name="Charlie", age=35)),
+ (2, Row(name="David", age=40)),
+ ],
+ "id int, person struct<name:string,age:int>",
+ )
+
+ # Test non-iterator grouped agg UDF with struct input
+ # Note: Currently struct types are passed as Series of dicts when df_for_struct=False.
+ # This test verifies the behavior and documents the expected interface.
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def avg_age(person: pd.Series) -> float:
+ # Currently struct types are passed as Series of dicts
+ # In the future, they should be passed as pd.DataFrame (like scalar pandas UDFs)
+ assert isinstance(person, pd.Series), f"Expected Series, got {type(person)}"
+ # Extract age values from dicts
+ ages = [p["age"] for p in person]
+ return sum(ages) / len(ages) if ages else 0.0
+
+ result = df.groupby("id").agg(avg_age(df["person"]).alias("avg_age")).sort("id")
+ actual = result.collect()
+
+ # Group 1: (25 + 30) / 2 = 27.5
+ # Group 2: (35 + 40) / 2 = 37.5
+ expected = [Row(id=1, avg_age=27.5), Row(id=2, avg_age=37.5)]
+ self.assertEqual(actual, expected)
+
+ def test_iterator_grouped_agg_with_struct_type_input(self):
+ """
+ Test that iterator grouped agg UDF works with struct type input.
+ Struct types should be passed as pd.DataFrame to the UDF (similar to scalar pandas UDFs).
+ """
+ from pyspark.sql import Row
+
+ # Create a DataFrame with struct column
+ df = self.spark.createDataFrame(
+ [
+ (1, Row(name="Alice", age=25)),
+ (1, Row(name="Bob", age=30)),
+ (2, Row(name="Charlie", age=35)),
+ (2, Row(name="David", age=40)),
+ ],
+ "id int, person struct<name:string,age:int>",
+ )
+
+ # Test iterator grouped agg UDF with struct input
+ # Note: Currently struct types are passed as Series of dicts when df_for_struct=False.
+ # This test verifies the behavior and documents the expected interface.
+ @pandas_udf("double")
+ def avg_age_iter(it: Iterator[pd.Series]) -> float:
+ total_age = 0.0
+ count = 0
+ for person_series in it:
+ # Currently struct types are passed as Series of dicts
+ # In the future, they should be passed as pd.DataFrame (like scalar pandas UDFs)
+ assert isinstance(
+ person_series, pd.Series
+ ), f"Expected Series, got {type(person_series)}"
+ # Extract age values from dicts
+ ages = [p["age"] for p in person_series]
+ total_age += sum(ages)
+ count += len(ages)
+ return total_age / count if count > 0 else 0.0
+
+ result = df.groupby("id").agg(avg_age_iter(df["person"]).alias("avg_age")).sort("id")
+ actual = result.collect()
+
+ # Group 1: (25 + 30) / 2 = 27.5
+ # Group 2: (35 + 40) / 2 = 37.5
+ expected = [Row(id=1, avg_age=27.5), Row(id=2, avg_age=37.5)]
+ self.assertEqual(actual, expected)
+
class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
index d4c1cf5d62ff..e3d87ededaa7 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
@@ -187,6 +187,47 @@ class PandasUDFTypeHintsTests(ReusedSQLTestCase):
PandasUDFType.GROUPED_AGG,
)
+ def test_type_annotation_group_agg_iter(self):
+ # Iterator[pd.Series] -> Any
+ def func(iter: Iterator[pd.Series]) -> float:
+ pass
+
+ self.assertEqual(
+ infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG_ITER
+ )
+
+ # Iterator[Tuple[pd.Series, pd.Series]] -> Any
+ def func(iter: Iterator[Tuple[pd.Series, pd.Series]]) -> int:
+ pass
+
+ self.assertEqual(
+ infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG_ITER
+ )
+
+ # Iterator[Tuple[pd.Series, ...]] -> Any
+ def func(iter: Iterator[Tuple[pd.Series, ...]]) -> str:
+ pass
+
+ self.assertEqual(
+ infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG_ITER
+ )
+
+ # Union[pd.Series, pd.Series] equals to pd.Series
+ def func(iter: Iterator[Union[pd.Series, pd.Series]]) -> float:
+ pass
+
+ self.assertEqual(
+ infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG_ITER
+ )
+
+ # Iterator[tuple[pd.Series, pd.Series]] -> Any
+ def func(iter: Iterator[tuple[pd.Series, pd.Series]]) -> float:
+ pass
+
+ self.assertEqual(
+ infer_eval_type(signature(func), get_type_hints(func)), PandasUDFType.GROUPED_AGG_ITER
+ )
+
def test_type_annotation_group_map(self):
# pd.DataFrame -> pd.DataFrame
def func(col: pd.DataFrame) -> pd.DataFrame:
@@ -344,6 +385,55 @@ class PandasUDFTypeHintsTests(ReusedSQLTestCase):
expected = df.groupby("id").agg(mean(df.v).alias("weighted_mean(v,
1.0)")).sort("id")
assert_frame_equal(expected.toPandas(), actual.toPandas())
+ def test_group_agg_iter_udf_type_hint(self):
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ def pandas_mean_iter(it: Iterator[pd.Series]) -> float:
+ sum_val = 0.0
+ cnt = 0
+ for series in it:
+ sum_val += series.sum()
+ cnt += len(series)
+ return sum_val / cnt if cnt > 0 else 0.0
+
+ pandas_mean_iter = pandas_udf("double")(pandas_mean_iter)
+
+ actual = df.groupby("id").agg(pandas_mean_iter(df["v"]).alias("mean")).sort("id")
+ expected = df.groupby("id").agg(mean(df["v"]).alias("mean")).sort("id")
+ assert_frame_equal(expected.toPandas(), actual.toPandas())
+
+ # Test with Tuple for multiple columns
+ df2 = self.spark.createDataFrame(
+ [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+ ("id", "v", "w"),
+ )
+
+ def pandas_weighted_mean_iter(it: Iterator[Tuple[pd.Series, pd.Series]]) -> float:
+ import numpy as np
+
+ weighted_sum = 0.0
+ weight = 0.0
+ for v_series, w_series in it:
+ weighted_sum += np.dot(v_series, w_series)
+ weight += w_series.sum()
+ return weighted_sum / weight if weight > 0 else 0.0
+
+ pandas_weighted_mean_iter = pandas_udf("double")(pandas_weighted_mean_iter)
+
+ actual2 = (
+ df2.groupby("id")
+ .agg(pandas_weighted_mean_iter(df2["v"], df2["w"]).alias("wm"))
+ .sort("id")
+ )
+ # Expected weighted means:
+ # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+ # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 / 6.0
+ expected = [Row(id=1, wm=5.0 / 3.0), Row(id=2, wm=43.0 / 6.0)]
+ actual_results = actual2.collect()
+ self.assertEqual(actual_results, expected)
+
def test_ignore_type_hint_in_group_apply_in_pandas(self):
df = self.spark.range(10)
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 22c653508fbb..a9c36fb2ae6b 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -65,6 +65,7 @@ if typing.TYPE_CHECKING:
ArrowGroupedMapIterUDFType,
ArrowCogroupedMapUDFType,
PandasGroupedMapIterUDFType,
+ PandasGroupedAggIterUDFType,
PandasGroupedMapUDFTransformWithStateType,
PandasGroupedMapUDFTransformWithStateInitStateType,
GroupedMapUDFTransformWithStateType,
@@ -653,6 +654,7 @@ class PythonEvalType:
)
SQL_GROUPED_MAP_ARROW_ITER_UDF: "ArrowGroupedMapIterUDFType" = 215
SQL_GROUPED_MAP_PANDAS_ITER_UDF: "PandasGroupedMapIterUDFType" = 216
+ SQL_GROUPED_AGG_PANDAS_ITER_UDF: "PandasGroupedAggIterUDFType" = 217
# Arrow UDFs
SQL_SCALAR_ARROW_UDF: "ArrowScalarUDFType" = 250
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e604afab51c9..be1dcb215e2d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -65,6 +65,7 @@ from pyspark.sql.pandas.serializers import (
TransformWithStateInPySparkRowInitStateSerializer,
ArrowStreamArrowUDFSerializer,
ArrowStreamAggPandasUDFSerializer,
+ ArrowStreamAggPandasIterUDFSerializer,
ArrowStreamAggArrowUDFSerializer,
ArrowBatchUDFSerializer,
ArrowStreamUDTFSerializer,
@@ -1094,6 +1095,28 @@ def wrap_grouped_agg_arrow_iter_udf(f, args_offsets, kwargs_offsets, return_type
)
+def wrap_grouped_agg_pandas_iter_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
+ func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)
+
+ arrow_return_type = to_arrow_type(
+ return_type, prefers_large_types=runner_conf.use_large_var_types
+ )
+
+ def wrapped(series_iter):
+ import pandas as pd
+
+ # series_iter: Iterator[pd.Series] (single column) or
+ # Iterator[Tuple[pd.Series, ...]] (multiple columns)
+ # This has already been adapted by the mapper function in read_udfs
+ result = func(series_iter)
+ return pd.Series([result])
+
+ return (
+ args_kwargs_offsets,
+ lambda *a: (wrapped(*a), arrow_return_type),
+ )
+
+
def wrap_window_agg_pandas_udf(
f, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
):
@@ -1398,6 +1421,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
):
args_offsets = []
kwargs_offsets = {}
@@ -1507,6 +1531,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
return wrap_grouped_agg_arrow_iter_udf(
func, args_offsets, kwargs_offsets, return_type, runner_conf
)
+ elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
+ return wrap_grouped_agg_pandas_iter_udf(
+ func, args_offsets, kwargs_offsets, return_type, runner_conf
+ )
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
return wrap_window_agg_pandas_udf(
func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
@@ -2713,6 +2741,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
@@ -2751,6 +2780,13 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
ser = ArrowStreamAggArrowUDFSerializer(
runner_conf.timezone, True, runner_conf.assign_cols_by_name, True
)
+ elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
+ ser = ArrowStreamAggPandasIterUDFSerializer(
+ runner_conf.timezone,
+ runner_conf.safecheck,
+ runner_conf.assign_cols_by_name,
+ runner_conf.int_to_decimal_coercion_enabled,
+ )
elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
@@ -3267,6 +3303,30 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf):
batch_iter = (tuple(batch_columns[o] for o in arg_offsets) for batch_columns in a)
return f(batch_iter)
+ elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF:
+ # We assume there is only one UDF here because grouped agg doesn't
+ # support combining multiple UDFs.
+ assert num_udfs == 1
+
+ arg_offsets, f = udfs[0]
+
+ # Convert to iterator of pandas Series:
+ # - Iterator[pd.Series] for single column
+ # - Iterator[Tuple[pd.Series, ...]] for multiple columns
+ def mapper(batch_iter):
+ # batch_iter is Iterator[Tuple[pd.Series, ...]] where each tuple represents one batch
+ # Convert to Iterator[pd.Series] or Iterator[Tuple[pd.Series, ...]] based on arg_offsets
+ if len(arg_offsets) == 1:
+ # Single column: Iterator[Tuple[pd.Series, ...]] -> Iterator[pd.Series]
+ series_iter = (batch_series[arg_offsets[0]] for batch_series in batch_iter)
+ else:
+ # Multiple columns: Iterator[Tuple[pd.Series, ...]] ->
+ # Iterator[Tuple[pd.Series, ...]]
+ series_iter = (
+ tuple(batch_series[o] for o in arg_offsets) for batch_series in batch_iter
+ )
+ return f(series_iter)
+
elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
index 48eca0896292..3e76008e2e61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
@@ -42,6 +42,7 @@ import org.apache.spark.util.Utils
* <li> SQL_GROUPED_AGG_ARROW_UDF for Arrow UDF
* <li> SQL_GROUPED_AGG_ARROW_ITER_UDF for Arrow UDF with iterator API
* <li> SQL_GROUPED_AGG_PANDAS_UDF for Pandas UDF
+ * <li> SQL_GROUPED_AGG_PANDAS_ITER_UDF for Pandas UDF with iterator API
* </ul>
*
* This plan works by sending the necessary (projected) input grouped data as Arrow record batches
@@ -240,7 +241,8 @@ case class ArrowAggregatePythonExec(
Array(
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
- PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF)
}
object ArrowAggregatePythonExec {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index bcbc12573099..41bffaca65cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -50,6 +50,7 @@ case class UserDefinedPythonFunction(
|| pythonEvalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
|| pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
|| pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
+ || pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF
|| pythonEvalType == PythonEvalType.SQL_SCALAR_ARROW_UDF
|| pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
|| pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF) {
@@ -64,6 +65,7 @@ case class UserDefinedPythonFunction(
}
if (pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
+ || pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF
|| pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
|| pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF) {
PythonUDAF(name, func, dataType, e, udfDeterministic, pythonEvalType)