This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4202f239c45a [SPARK-53614][PYTHON] Add `Iterator[pandas.DataFrame]` support to `applyInPandas`
4202f239c45a is described below
commit 4202f239c45a290e340bcd505de849d876f992fa
Author: Yicong-Huang <[email protected]>
AuthorDate: Fri Oct 31 14:33:05 2025 +0800
[SPARK-53614][PYTHON] Add `Iterator[pandas.DataFrame]` support to `applyInPandas`
### What changes were proposed in this pull request?
This PR adds support for the `Iterator[pandas.DataFrame]` API in `groupBy().applyInPandas()`, enabling batch-by-batch processing of grouped data for improved memory efficiency and scalability.
#### Key Changes:
1. **New PythonEvalType**: Added `SQL_GROUPED_MAP_PANDAS_ITER_UDF` to distinguish iterator-based UDFs from standard grouped map UDFs
2. **Type Inference**: Implemented automatic detection of iterator signatures (see the sketch after this list):
   - `Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]`
   - `Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]`
3. **Streaming Serialization**: Created `GroupPandasIterUDFSerializer`, which streams results without materializing all DataFrames in memory
4. **Execution Fix**: Updated `FlatMapGroupsInPandasExec`, which previously hardcoded `pythonEvalType = 201` instead of extracting it from the UDF expression (mirroring the fix in `FlatMapGroupsInArrowExec`)
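The sketch referenced in item 2 above: a minimal example of the type-hint inference, mirroring the new `test_type_annotation_group_map` test (the function names `per_group` and `per_batch` are illustrative only).
```python
from inspect import signature
from typing import Any, Iterator, Tuple, get_type_hints

import pandas as pd

from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type
from pyspark.util import PythonEvalType


def per_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # classic DataFrame-to-DataFrame form
    return pdf


def per_batch(key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # new iterator form, with the grouping key as the first argument
    yield from batches


# The classic signature keeps the existing grouped map eval type (201) ...
assert (
    infer_group_pandas_eval_type(signature(per_group), get_type_hints(per_group))
    == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
)
# ... while iterator type hints select the new SQL_GROUPED_MAP_PANDAS_ITER_UDF (216).
assert (
    infer_group_pandas_eval_type(signature(per_batch), get_type_hints(per_batch))
    == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
)
```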
### Why are the changes needed?
The existing `applyInPandas()` API loads entire groups into memory as
single DataFrames. For large groups, this can cause OOM errors. The iterator
API allows:
- **Memory Efficiency**: Process data batch-by-batch instead of materializing entire groups
- **Scalability**: Handle arbitrarily large groups that don't fit in memory
- **Consistency**: Mirrors the existing `applyInArrow()` iterator API design
### Does this PR introduce any user-facing changes?
Yes, this PR adds a new API variant for `applyInPandas()`:
#### Before (existing API, still supported):
```python
import pandas as pd

def normalize(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.assign(v=(pdf.v - pdf.v.mean()) / pdf.v.std())

df.groupBy("id").applyInPandas(normalize, schema="id long, v double")
```
#### After (new iterator API):
```python
import pandas as pd
from typing import Iterator

def normalize(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Process data batch-by-batch
    for batch in batches:
        yield batch.assign(v=(batch.v - batch.v.mean()) / batch.v.std())

df.groupBy("id").applyInPandas(normalize, schema="id long, v double")
```
#### With Grouping Keys:
```python
import pandas as pd
from typing import Any, Iterator, Tuple

def sum_by_key(key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    total = 0
    for batch in batches:
        total += batch['v'].sum()
    yield pd.DataFrame({"id": [key[0]], "total": [total]})

df.groupBy("id").applyInPandas(sum_by_key, schema="id long, total double")
```
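The iterator form is not limited to a single output per group: the UDF may yield any number of `pandas.DataFrame` batches, and `GroupPandasIterUDFSerializer` streams them back without materializing them together. A minimal sketch along the lines of the added `test_apply_in_pandas_iterator_multiple_output_batches` test (the function name `explode_rows` is illustrative only):
```python
import pandas as pd
from typing import Iterator

def explode_rows(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Yield several small output batches per input batch; the results are
    # streamed back and combined into a single output DataFrame.
    for batch in batches:
        for _, row in batch.iterrows():
            yield pd.DataFrame({"id": [row["id"]], "v": [row["v"] * 2]})

df.groupBy("id").applyInPandas(explode_rows, schema="id long, v double")
```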
**Backward Compatibility**: The existing DataFrame-to-DataFrame API is
fully preserved and continues to work without changes.
### How was this patch tested?
- Added `test_apply_in_pandas_iterator_basic` - Basic functionality test
- Added `test_apply_in_pandas_iterator_with_keys` - Test with grouping keys
- Added `test_apply_in_pandas_iterator_batch_slicing` - Pressure test with 10M rows, 20 columns (batch sizing sketched below)
- Added `test_apply_in_pandas_iterator_with_keys_batch_slicing` - Pressure test with keys
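The batch sizing sketched below is how the pressure tests force each group to arrive at the UDF as multiple `pandas.DataFrame` batches; the conf names and values are taken from the added tests (shown with a plain conf set rather than the test helper `self.sql_conf`):
```python
# Slice each group into many input batches so the iterator UDF
# actually sees more than one pandas.DataFrame per group.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 1000)
spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", 1048576)
```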
### Was this patch authored or co-authored using generative AI tooling?
Yes, tests generated by Cursor.
Closes #52716 from Yicong-Huang/SPARK-53614/feat/add-apply-in-pandas.
Lead-authored-by: Yicong-Huang <[email protected]>
Co-authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../org/apache/spark/api/python/PythonRunner.scala | 3 +
python/pyspark/sql/connect/group.py | 15 +-
python/pyspark/sql/pandas/_typing/__init__.pyi | 3 +
python/pyspark/sql/pandas/functions.py | 19 +
python/pyspark/sql/pandas/functions.pyi | 14 +-
python/pyspark/sql/pandas/group_ops.py | 96 ++++-
python/pyspark/sql/pandas/serializers.py | 82 +++++
python/pyspark/sql/pandas/typehints.py | 95 +++++
.../sql/tests/pandas/test_pandas_grouped_map.py | 395 ++++++++++++++++++++-
python/pyspark/sql/tests/pandas/test_pandas_udf.py | 5 +-
.../sql/tests/pandas/test_pandas_udf_typehints.py | 46 ++-
python/pyspark/sql/udf.py | 9 +-
python/pyspark/util.py | 2 +
python/pyspark/worker.py | 74 ++++
.../sql/connect/planner/SparkConnectPlanner.scala | 3 +-
.../sql/classic/RelationalGroupedDataset.scala | 3 +-
.../python/FlatMapGroupsInPandasExec.scala | 3 +-
17 files changed, 837 insertions(+), 30 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index b3208980da24..66e204fee44b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -67,6 +67,7 @@ private[spark] object PythonEvalType {
val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF = 213
val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214
val SQL_GROUPED_MAP_ARROW_ITER_UDF = 215
+ val SQL_GROUPED_MAP_PANDAS_ITER_UDF = 216
// Arrow UDFs
val SQL_SCALAR_ARROW_UDF = 250
@@ -102,6 +103,8 @@ private[spark] object PythonEvalType {
case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF =>
"SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF"
case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
"SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF"
+ case SQL_GROUPED_MAP_ARROW_ITER_UDF => "SQL_GROUPED_MAP_ARROW_ITER_UDF"
+ case SQL_GROUPED_MAP_PANDAS_ITER_UDF => "SQL_GROUPED_MAP_PANDAS_ITER_UDF"
// Arrow UDFs
case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
diff --git a/python/pyspark/sql/connect/group.py
b/python/pyspark/sql/connect/group.py
index 4bba2b1e0b71..52d280c2c264 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -294,14 +294,25 @@ class GroupedData:
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
from pyspark.sql.connect.dataframe import DataFrame
+        from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func
-        _validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+ # Try to infer the eval type from type hints
+ eval_type = None
+ try:
+ eval_type = infer_group_pandas_eval_type_from_func(func)
+ except Exception:
+            warnings.warn("Cannot infer the eval type from type hints.", UserWarning)
+
+ if eval_type is None:
+ eval_type = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+
+ _validate_vectorized_udf(func, eval_type)
if isinstance(schema, str):
schema = cast(StructType, self._df._session._parse_ddl(schema))
udf_obj = UserDefinedFunction(
func,
returnType=schema,
- evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ evalType=eval_type,
)
res = DataFrame(
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi
b/python/pyspark/sql/pandas/_typing/__init__.pyi
index cea44921069f..4841b50544fd 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -61,6 +61,7 @@ PandasGroupedMapUDFTransformWithStateInitStateType =
Literal[212]
GroupedMapUDFTransformWithStateType = Literal[213]
GroupedMapUDFTransformWithStateInitStateType = Literal[214]
ArrowGroupedMapIterUDFType = Literal[215]
+PandasGroupedMapIterUDFType = Literal[216]
# Arrow UDFs
ArrowScalarUDFType = Literal[250]
@@ -347,6 +348,8 @@ PandasScalarIterFunction = Union[
PandasGroupedMapFunction = Union[
Callable[[DataFrameLike], DataFrameLike],
Callable[[Any, DataFrameLike], DataFrameLike],
+ Callable[[Iterator[DataFrameLike]], Iterator[DataFrameLike]],
+ Callable[[Any, Iterator[DataFrameLike]], Iterator[DataFrameLike]],
]
PandasGroupedMapFunctionWithState = Callable[
diff --git a/python/pyspark/sql/pandas/functions.py
b/python/pyspark/sql/pandas/functions.py
index 036905766a56..f115a91fce09 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -322,6 +322,7 @@ def arrow_udf(f=None, returnType=None, functionType=None):
pyspark.sql.GroupedData.applyInArrow
pyspark.sql.PandasCogroupedOps.applyInArrow
pyspark.sql.UDFRegistration.register
+ pyspark.sql.GroupedData.applyInPandas
"""
require_minimum_pyarrow_version()
@@ -346,6 +347,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. versionchanged:: 4.0.0
Supports keyword-arguments in SCALAR and GROUPED_AGG type.
+ .. versionchanged:: 4.1.0
+ Supports iterator API in GROUPED_MAP type.
+
Parameters
----------
f : function, optional
@@ -690,6 +694,7 @@ def vectorized_udf(
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
@@ -771,6 +776,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
)
elif evalType in [
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
@@ -836,6 +842,19 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
},
)
+    if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF and len(argspec.args) not in (
+        1,
+        2,
+    ):
+        raise PySparkValueError(
+            errorClass="INVALID_PANDAS_UDF",
+            messageParameters={
+                "detail": "the function in groupby.applyInPandas with iterator API must take "
+                "either one argument (batches: Iterator[pandas.DataFrame]) or two arguments "
+                "(key, batches: Iterator[pandas.DataFrame]).",
+            },
+        )
+
     if evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF and len(argspec.args) not in (1, 2):
raise PySparkValueError(
errorClass="INVALID_PANDAS_UDF",
diff --git a/python/pyspark/sql/pandas/functions.pyi
b/python/pyspark/sql/pandas/functions.pyi
index 70ff08679b6b..b9417583b7d4 100644
--- a/python/pyspark/sql/pandas/functions.pyi
+++ b/python/pyspark/sql/pandas/functions.pyi
@@ -29,6 +29,7 @@ from pyspark.sql.pandas._typing import (
PandasGroupedAggFunction,
PandasGroupedAggUDFType,
PandasGroupedMapFunction,
+ PandasGroupedMapIterUDFType,
PandasGroupedMapUDFType,
PandasScalarIterFunction,
PandasScalarIterUDFType,
@@ -145,19 +146,24 @@ def pandas_udf(
def pandas_udf(
f: PandasGroupedMapFunction,
returnType: Union[StructType, str],
- functionType: PandasGroupedMapUDFType,
+ functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
) -> GroupedMapPandasUserDefinedFunction: ...
@overload
def pandas_udf(
- f: Union[StructType, str], returnType: PandasGroupedMapUDFType
+ f: Union[StructType, str],
+ returnType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
 ) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ...
 @overload
 def pandas_udf(
-    *, returnType: Union[StructType, str], functionType: PandasGroupedMapUDFType
+    *,
+    returnType: Union[StructType, str],
+    functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
 ) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ...
 @overload
 def pandas_udf(
-    f: Union[StructType, str], *, functionType: PandasGroupedMapUDFType
+    f: Union[StructType, str],
+    *,
+    functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
 ) -> Callable[[PandasGroupedMapFunction], GroupedMapPandasUserDefinedFunction]: ...
@overload
def pandas_udf(
diff --git a/python/pyspark/sql/pandas/group_ops.py
b/python/pyspark/sql/pandas/group_ops.py
index 07d78a6ce6d8..1b4aa8798727 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -123,12 +123,13 @@ class PandasGroupedOpsMixin:
Maps each group of the current :class:`DataFrame` using a pandas udf
and returns the result
as a `DataFrame`.
- The function should take a `pandas.DataFrame` and return another
- `pandas.DataFrame`. Alternatively, the user can pass a function that
takes
- a tuple of the grouping key(s) and a `pandas.DataFrame`.
- For each group, all columns are passed together as a `pandas.DataFrame`
- to the user-function and the returned `pandas.DataFrame` are combined
as a
- :class:`DataFrame`.
+ The function can take one of two forms: It can take a
`pandas.DataFrame` and return a
+ `pandas.DataFrame`, or it can take an iterator of `pandas.DataFrame`
and yield
+ `pandas.DataFrame`. Alternatively each form can take a tuple of
grouping keys
+ as the first argument in addition to the input type above.
+ For each group, all columns are passed together as a
`pandas.DataFrame` or iterator of
+ `pandas.DataFrame`, and the returned `pandas.DataFrame` or iterator of
`pandas.DataFrame`
+ are combined as a :class:`DataFrame`.
The `schema` should be a :class:`StructType` describing the schema of
the returned
`pandas.DataFrame`. The column labels of the returned
`pandas.DataFrame` must either match
@@ -141,12 +142,17 @@ class PandasGroupedOpsMixin:
.. versionchanged:: 3.4.0
Support Spark Connect.
+ .. versionchanged:: 4.1.0
+ Added support for an iterator of `pandas.DataFrame` API.
+
Parameters
----------
func : function
- a Python native function that takes a `pandas.DataFrame` and
outputs a
- `pandas.DataFrame`, or that takes one tuple (grouping keys) and a
- `pandas.DataFrame` and outputs a `pandas.DataFrame`.
+ a Python native function that either takes a `pandas.DataFrame`
and outputs a
+ `pandas.DataFrame` or takes an iterator of `pandas.DataFrame` and
yields
+ `pandas.DataFrame`. Additionally, each form can take a tuple of
grouping keys
+ as the first argument, with the `pandas.DataFrame` or iterator of
`pandas.DataFrame`
+ as the second argument.
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type
string.
@@ -214,22 +220,84 @@ class PandasGroupedOpsMixin:
| 2| 2| 3.0|
+---+-----------+----+
+ The function can also take and return an iterator of
`pandas.DataFrame` using type
+ hints.
+
+ >>> from typing import Iterator # doctest: +SKIP
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v")) # doctest: +SKIP
+ >>> def filter_func(
+ ... batches: Iterator[pd.DataFrame]
+ ... ) -> Iterator[pd.DataFrame]: # doctest: +SKIP
+ ... for batch in batches:
+ ... # Process and yield each batch independently
+ ... filtered = batch[batch['v'] > 2.0]
+ ... if not filtered.empty:
+ ... yield filtered[['v']]
+ >>> df.groupby("id").applyInPandas(
+ ... filter_func, schema="v double").show() # doctest: +SKIP
+ +----+
+ | v|
+ +----+
+ | 3.0|
+ | 5.0|
+ |10.0|
+ +----+
+
+ Alternatively, the user can pass a function that takes two arguments.
+ In this case, the grouping key(s) will be passed as the first argument
and the data will
+ be passed as the second argument. The grouping key(s) will be passed
as a tuple of numpy
+ data types. The data will still be passed in as an iterator of
`pandas.DataFrame`.
+
+ >>> from typing import Iterator, Tuple, Any # doctest: +SKIP
+ >>> def transform_func(
+ ... key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+ ... ) -> Iterator[pd.DataFrame]: # doctest: +SKIP
+ ... for batch in batches:
+ ... # Yield transformed results for each batch
+ ... result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
+ ... yield result[['id', 'v_doubled']]
+ >>> df.groupby("id").applyInPandas(
+ ... transform_func, schema="id long, v_doubled double").show() #
doctest: +SKIP
+ +---+----------+
+ | id|v_doubled |
+ +---+----------+
+ | 1| 2.0|
+ | 1| 4.0|
+ | 2| 6.0|
+ | 2| 10.0|
+ | 2| 20.0|
+ +---+----------+
+
Notes
-----
- This function requires a full shuffle. All the data of a group will be
loaded
- into memory, so the user should be aware of the potential OOM risk if
data is skewed
- and certain groups are too large to fit in memory.
+ This function requires a full shuffle. If using the `pandas.DataFrame`
API, all data of a
+ group will be loaded into memory, so the user should be aware of the
potential OOM risk if
+ data is skewed and certain groups are too large to fit in memory, and
can use the
+ iterator of `pandas.DataFrame` API to mitigate this.
See Also
--------
pyspark.sql.functions.pandas_udf
"""
from pyspark.sql import GroupedData
- from pyspark.sql.functions import pandas_udf, PandasUDFType
+ from pyspark.sql.functions import pandas_udf
+        from pyspark.sql.pandas.typehints import infer_group_pandas_eval_type_from_func
         assert isinstance(self, GroupedData)
-        udf = pandas_udf(func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
+        # Try to infer the eval type from type hints
+        eval_type = None
+        try:
+            eval_type = infer_group_pandas_eval_type_from_func(func)
+        except Exception as e:
+            warnings.warn(f"Cannot infer the eval type from type hints: {e}", UserWarning)
+
+ if eval_type is None:
+ eval_type = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+
+ udf = pandas_udf(func, returnType=schema, functionType=eval_type)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc)
diff --git a/python/pyspark/sql/pandas/serializers.py
b/python/pyspark/sql/pandas/serializers.py
index ef8b8c0e6421..56d338fc1371 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1251,6 +1251,88 @@ class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
return "GroupPandasUDFSerializer"
+class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
+ """
+ Serializer for grouped map Pandas iterator UDFs.
+
+    Loads grouped data as pandas.Series and serializes results from iterator UDFs.
+    Flattens the (dataframes_generator, arrow_type) tuple by iterating over the generator.
+ """
+
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ int_to_decimal_coercion_enabled,
+ ):
+ super(GroupPandasIterUDFSerializer, self).__init__(
+ timezone=timezone,
+ safecheck=safecheck,
+ assign_cols_by_name=assign_cols_by_name,
+ df_for_struct=False,
+ struct_in_pandas="dict",
+ ndarray_as_list=False,
+ arrow_cast=True,
+ input_types=None,
+ int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ )
+
+ def load_stream(self, stream):
+ """
+        Deserialize Grouped ArrowRecordBatches and yield a generator of pandas.Series lists
+        (one list per batch), allowing the iterator UDF to process data batch-by-batch.
+ """
+ import pyarrow as pa
+
+ def process_group(batches: "Iterator[pa.RecordBatch]"):
+            # Convert each Arrow batch to pandas Series list on-demand, yielding one list per batch
+            for batch in batches:
+                series = [
+                    self.arrow_to_pandas(c, i)
+                    for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
+ ]
+ yield series
+
+ dataframes_in_group = None
+
+ while dataframes_in_group is None or dataframes_in_group > 0:
+ dataframes_in_group = read_int(stream)
+
+ if dataframes_in_group == 1:
+                # Lazily read and convert Arrow batches one at a time from the stream
+                # This avoids loading all batches into memory for the group
+                batch_iter = process_group(ArrowStreamSerializer.load_stream(self, stream))
+                yield batch_iter
+                # Make sure the batches are fully iterated before getting the next group
+ for _ in batch_iter:
+ pass
+
+ elif dataframes_in_group != 0:
+ raise PySparkValueError(
+ errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
+ )
+
+ def dump_stream(self, iterator, stream):
+ """
+        Flatten the (dataframes_generator, arrow_type) tuples by iterating over each generator.
+        This allows the iterator UDF to stream results without materializing all DataFrames.
+        """
+        # Flatten: (dataframes_generator, arrow_type) -> (df, arrow_type), (df, arrow_type), ...
+        flattened_iter = (
+            (df, arrow_type) for dataframes_gen, arrow_type in iterator for df in dataframes_gen
+        )
+
+        # Convert each (df, arrow_type) to the format expected by parent's dump_stream
+        series_iter = ([(df, arrow_type)] for df, arrow_type in flattened_iter)
+
+        super(GroupPandasIterUDFSerializer, self).dump_stream(series_iter, stream)
+
+ def __repr__(self):
+ return "GroupPandasIterUDFSerializer"
+
+
class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
"""
Serializes pyarrow.RecordBatch data with Arrow streaming format.
diff --git a/python/pyspark/sql/pandas/typehints.py
b/python/pyspark/sql/pandas/typehints.py
index c184e0dc5668..18858ab0cf68 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -32,6 +32,9 @@ if TYPE_CHECKING:
ArrowGroupedMapIterUDFType,
ArrowGroupedMapUDFType,
ArrowGroupedMapFunction,
+ PandasGroupedMapFunction,
+ PandasGroupedMapUDFType,
+ PandasGroupedMapIterUDFType,
)
@@ -394,6 +397,98 @@ def infer_group_arrow_eval_type_from_func(
return None
+def infer_group_pandas_eval_type(
+ sig: Signature,
+ type_hints: Dict[str, Any],
+) -> Optional[Union["PandasGroupedMapUDFType", "PandasGroupedMapIterUDFType"]]:
+ from pyspark.sql.pandas.functions import PythonEvalType
+
+ require_minimum_pandas_version()
+
+ import pandas as pd
+
+ annotations = {}
+ for param in sig.parameters.values():
+ if param.annotation is not param.empty:
+            annotations[param.name] = type_hints.get(param.name, param.annotation)
+
+ # Check if all arguments have type hints
+ parameters_sig = [
+        annotations[parameter] for parameter in sig.parameters if parameter in annotations
+ ]
+ if len(parameters_sig) != len(sig.parameters):
+ raise PySparkValueError(
+ errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+ messageParameters={"target": "all parameters", "sig": str(sig)},
+ )
+
+ # Check if the return has a type hint
+ return_annotation = type_hints.get("return", sig.return_annotation)
+ if sig.empty is return_annotation:
+ raise PySparkValueError(
+ errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+ messageParameters={"target": "the return type", "sig": str(sig)},
+ )
+
+ # Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+ is_iterator_dataframe = (
+ len(parameters_sig) == 1
+ and check_iterator_annotation( # Iterator
+ parameters_sig[0],
+ parameter_check_func=lambda t: t == pd.DataFrame,
+ )
+ and check_iterator_annotation(
+ return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
+ )
+ )
+ # Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+ is_iterator_dataframe_with_keys = (
+ len(parameters_sig) == 2
+ and check_iterator_annotation( # Iterator
+ parameters_sig[1],
+ parameter_check_func=lambda t: t == pd.DataFrame,
+ )
+ and check_iterator_annotation(
+ return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
+ )
+ )
+
+ if is_iterator_dataframe or is_iterator_dataframe_with_keys:
+ return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
+
+ # pd.DataFrame -> pd.DataFrame
+ is_dataframe = (
+ len(parameters_sig) == 1
+ and parameters_sig[0] == pd.DataFrame
+ and return_annotation == pd.DataFrame
+ )
+ # Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
+ is_dataframe_with_keys = (
+ len(parameters_sig) == 2
+ and parameters_sig[1] == pd.DataFrame
+ and return_annotation == pd.DataFrame
+ )
+ if is_dataframe or is_dataframe_with_keys:
+ return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+
+ return None
+
+
+def infer_group_pandas_eval_type_from_func(
+ f: "PandasGroupedMapFunction",
+) -> Optional[Union["PandasGroupedMapUDFType", "PandasGroupedMapIterUDFType"]]:
+ argspec = getfullargspec(f)
+ if len(argspec.annotations) > 0:
+ try:
+ type_hints = get_type_hints(f)
+ except NameError:
+ type_hints = {}
+
+ return infer_group_pandas_eval_type(signature(f), type_hints)
+ else:
+ return None
+
+
def check_tuple_annotation(
annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] =
None
) -> bool:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 0e922d072871..991fe67ea1f5 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -21,7 +21,7 @@ import logging
from collections import OrderedDict
from decimal import Decimal
-from typing import cast
+from typing import cast, Iterator, Tuple, Any
from pyspark.sql import Row, functions as sf
from pyspark.sql.functions import (
@@ -1024,6 +1024,399 @@ class ApplyInPandasTestsMixin:
],
)
+ def test_apply_in_pandas_iterator_basic(self):
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ def sum_func(batches: Iterator[pd.DataFrame]) ->
Iterator[pd.DataFrame]:
+ total = 0
+ for batch in batches:
+ total += batch["v"].sum()
+ yield pd.DataFrame({"v": [total]})
+
+ result = df.groupby("id").applyInPandas(sum_func, schema="v
double").orderBy("v").collect()
+ self.assertEqual(len(result), 2)
+ self.assertEqual(result[0][0], 3.0)
+ self.assertEqual(result[1][0], 18.0)
+
+ def test_apply_in_pandas_iterator_with_keys(self):
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ def sum_func(
+ key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+ ) -> Iterator[pd.DataFrame]:
+ total = 0
+ for batch in batches:
+ total += batch["v"].sum()
+ yield pd.DataFrame({"id": [key[0]], "v": [total]})
+
+ result = (
+ df.groupby("id")
+ .applyInPandas(sum_func, schema="id long, v double")
+ .orderBy("id")
+ .collect()
+ )
+ self.assertEqual(len(result), 2)
+ self.assertEqual(result[0][0], 1)
+ self.assertEqual(result[0][1], 3.0)
+ self.assertEqual(result[1][0], 2)
+ self.assertEqual(result[1][1], 18.0)
+
+ def test_apply_in_pandas_iterator_batch_slicing(self):
+ df = self.spark.range(10000000).select(
+ (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+ )
+ cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+ df = df.withColumns(cols)
+
+ def min_max_v(batches: Iterator[pd.DataFrame]) ->
Iterator[pd.DataFrame]:
+ # Collect all batches to compute min/max across the entire group
+ all_data = []
+ key_val = None
+ for batch in batches:
+ all_data.append(batch)
+ if key_val is None:
+ key_val = batch.key.iloc[0]
+
+ combined = pd.concat(all_data, ignore_index=True)
+ assert len(combined) == 10000000 / 2, len(combined)
+
+ yield pd.DataFrame(
+ {
+ "key": [key_val],
+ "min": [combined.v.min()],
+ "max": [combined.v.max()],
+ }
+ )
+
+ expected = (
+ df.groupby("key")
+ .agg(
+ sf.min("v").alias("min"),
+ sf.max("v").alias("max"),
+ )
+ .sort("key")
+ ).collect()
+
+ for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000,
1048576)]:
+ with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch":
maxRecords,
+ "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+ }
+ ):
+ result = (
+ df.groupBy("key")
+ .applyInPandas(min_max_v, "key long, min long, max
long")
+ .sort("key")
+ ).collect()
+
+ self.assertEqual(expected, result)
+
+ def test_apply_in_pandas_iterator_with_keys_batch_slicing(self):
+ df = self.spark.range(10000000).select(
+ (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+ )
+ cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+ df = df.withColumns(cols)
+
+ def min_max_v(
+ key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+ ) -> Iterator[pd.DataFrame]:
+ # Collect all batches to compute min/max across the entire group
+ all_data = []
+ for batch in batches:
+ all_data.append(batch)
+
+ combined = pd.concat(all_data, ignore_index=True)
+ assert len(combined) == 10000000 / 2, len(combined)
+
+ yield pd.DataFrame(
+ {
+ "key": [key[0]],
+ "min": [combined.v.min()],
+ "max": [combined.v.max()],
+ }
+ )
+
+ expected = (
+ df.groupby("key").agg(sf.min("v").alias("min"),
sf.max("v").alias("max")).sort("key")
+ ).collect()
+
+ for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000,
1048576)]:
+ with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch":
maxRecords,
+ "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+ }
+ ):
+ result = (
+ df.groupBy("key")
+ .applyInPandas(min_max_v, "key long, min long, max
long")
+ .sort("key")
+ ).collect()
+
+ self.assertEqual(expected, result)
+
+ def test_apply_in_pandas_iterator_multiple_output_batches(self):
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (1, 3.0), (2, 4.0), (2, 5.0), (2, 6.0)],
("id", "v")
+ )
+
+ def split_and_yield(batches: Iterator[pd.DataFrame]) ->
Iterator[pd.DataFrame]:
+ # Yield multiple output batches for each input batch
+ for batch in batches:
+ for _, row in batch.iterrows():
+ # Yield each row as a separate batch to test multiple
yields
+ yield pd.DataFrame(
+ {"id": [row["id"]], "v": [row["v"]], "v_doubled":
[row["v"] * 2]}
+ )
+
+ result = (
+ df.groupby("id")
+ .applyInPandas(split_and_yield, schema="id long, v double,
v_doubled double")
+ .orderBy("id", "v")
+ .collect()
+ )
+
+ expected = [
+ Row(id=1, v=1.0, v_doubled=2.0),
+ Row(id=1, v=2.0, v_doubled=4.0),
+ Row(id=1, v=3.0, v_doubled=6.0),
+ Row(id=2, v=4.0, v_doubled=8.0),
+ Row(id=2, v=5.0, v_doubled=10.0),
+ Row(id=2, v=6.0, v_doubled=12.0),
+ ]
+ self.assertEqual(result, expected)
+
+ def test_apply_in_pandas_iterator_filter_multiple_batches(self):
+ df = self.spark.createDataFrame(
+ [(1, i * 1.0) for i in range(20)] + [(2, i * 1.0) for i in
range(20)], ("id", "v")
+ )
+
+ def filter_and_yield(batches: Iterator[pd.DataFrame]) ->
Iterator[pd.DataFrame]:
+ # Yield filtered results from each batch
+ for batch in batches:
+ # Filter even values and yield
+ even_batch = batch[batch["v"] % 2 == 0]
+ if not even_batch.empty:
+ yield even_batch
+
+ # Filter odd values and yield separately
+ odd_batch = batch[batch["v"] % 2 == 1]
+ if not odd_batch.empty:
+ yield odd_batch
+
+ result = (
+ df.groupby("id")
+ .applyInPandas(filter_and_yield, schema="id long, v double")
+ .orderBy("id", "v")
+ .collect()
+ )
+
+ # Verify all 40 rows are present (20 per group)
+ self.assertEqual(len(result), 40)
+
+ # Verify group 1 has all values 0-19
+ group1 = [row for row in result if row[0] == 1]
+ self.assertEqual(len(group1), 20)
+ self.assertEqual([row[1] for row in group1], [float(i) for i in
range(20)])
+
+ # Verify group 2 has all values 0-19
+ group2 = [row for row in result if row[0] == 2]
+ self.assertEqual(len(group2), 20)
+ self.assertEqual([row[1] for row in group2], [float(i) for i in
range(20)])
+
+ def test_apply_in_pandas_iterator_with_keys_multiple_batches(self):
+ df = self.spark.createDataFrame(
+ [
+ (1, "a", 1.0),
+ (1, "b", 2.0),
+ (1, "c", 3.0),
+ (2, "d", 4.0),
+ (2, "e", 5.0),
+ (2, "f", 6.0),
+ ],
+ ("id", "name", "v"),
+ )
+
+ def process_with_key(
+ key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+ ) -> Iterator[pd.DataFrame]:
+ # Yield multiple processed batches, including the key in each
output
+ for batch in batches:
+ # Split batch and yield multiple output batches
+ for chunk_size in [1, 2]:
+ for i in range(0, len(batch), chunk_size):
+ chunk = batch.iloc[i : i + chunk_size]
+ if not chunk.empty:
+ result = chunk.assign(id=key[0],
total=chunk["v"].sum())
+ yield result[["id", "name", "total"]]
+
+ result = (
+ df.groupby("id")
+ .applyInPandas(process_with_key, schema="id long, name string,
total double")
+ .orderBy("id", "name")
+ .collect()
+ )
+
+ # Verify we get results (may have duplicates due to splitting)
+ self.assertTrue(len(result) > 6)
+
+ # Verify all original names are present
+ names = [row[1] for row in result]
+ self.assertIn("a", names)
+ self.assertIn("b", names)
+ self.assertIn("c", names)
+ self.assertIn("d", names)
+ self.assertIn("e", names)
+ self.assertIn("f", names)
+
+ # Verify keys are correct
+ for row in result:
+ if row[1] in ["a", "b", "c"]:
+ self.assertEqual(row[0], 1)
+ else:
+ self.assertEqual(row[0], 2)
+
+ def test_apply_in_pandas_iterator_process_multiple_input_batches(self):
+ import builtins
+
+ # Create large dataset to trigger batch slicing
+ df = self.spark.range(100000).select(
+ (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+ )
+
+ def process_batches_progressively(
+ batches: Iterator[pd.DataFrame],
+ ) -> Iterator[pd.DataFrame]:
+ # Process each input batch and yield output immediately
+ batch_count = 0
+ for batch in batches:
+ batch_count += 1
+ # Yield a summary for each input batch processed
+ yield pd.DataFrame(
+ {
+ "key": [batch.key.iloc[0]],
+ "batch_num": [batch_count],
+ "count": [len(batch)],
+ "sum": [batch.v.sum()],
+ }
+ )
+
+ # Use small batch size to force multiple input batches
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": 10000,
+ }
+ ):
+ result = (
+ df.groupBy("key")
+ .applyInPandas(
+ process_batches_progressively,
+ schema="key long, batch_num long, count long, sum long",
+ )
+ .orderBy("key", "batch_num")
+ .collect()
+ )
+
+ # Verify we got multiple batches per group (100000/2 = 50000 rows per
group)
+ # With maxRecordsPerBatch=10000, should get 5 batches per group
+ group_0_batches = [r for r in result if r[0] == 0]
+ group_1_batches = [r for r in result if r[0] == 1]
+
+ # Verify multiple batches were processed
+ self.assertGreater(len(group_0_batches), 1)
+ self.assertGreater(len(group_1_batches), 1)
+
+ # Verify the sum across all batches equals expected total (using
Python's built-in sum)
+ group_0_sum = builtins.sum(r[3] for r in group_0_batches)
+ group_1_sum = builtins.sum(r[3] for r in group_1_batches)
+
+ # Expected: sum of even numbers 0,2,4,...,99998
+ expected_even_sum = builtins.sum(range(0, 100000, 2))
+ expected_odd_sum = builtins.sum(range(1, 100000, 2))
+
+ self.assertEqual(group_0_sum, expected_even_sum)
+ self.assertEqual(group_1_sum, expected_odd_sum)
+
+ def test_apply_in_pandas_iterator_streaming_aggregation(self):
+ # Create dataset with multiple batches per group
+ df = self.spark.range(50000).select(
+ (sf.col("id") % 3).alias("key"),
+ (sf.col("id") % 100).alias("category"),
+ sf.col("id").alias("value"),
+ )
+
+ def streaming_aggregate(batches: Iterator[pd.DataFrame]) ->
Iterator[pd.DataFrame]:
+ # Maintain running aggregates and yield intermediate results
+ running_sum = 0
+ running_count = 0
+
+ for batch in batches:
+ # Update running aggregates
+ running_sum += batch.value.sum()
+ running_count += len(batch)
+
+ # Yield current stats after processing each batch
+ yield pd.DataFrame(
+ {
+ "key": [batch.key.iloc[0]],
+ "running_count": [running_count],
+ "running_avg": [running_sum / running_count],
+ }
+ )
+
+ # Force multiple batches with small batch size
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch": 5000,
+ }
+ ):
+ result = (
+ df.groupBy("key")
+ .applyInPandas(
+ streaming_aggregate, schema="key long, running_count long,
running_avg double"
+ )
+ .collect()
+ )
+
+ # Verify we got multiple rows per group (one per input batch)
+ for key_val in [0, 1, 2]:
+ key_results = [r for r in result if r[0] == key_val]
+ # Should have multiple batches
+ # (50000/3 ≈ 16667 rows per group, with 5000 per batch = ~4
batches)
+ self.assertGreater(len(key_results), 1, f"Expected multiple
batches for key {key_val}")
+
+ # Verify running_count increases monotonically
+ counts = [r[1] for r in key_results]
+ for i in range(1, len(counts)):
+ self.assertGreater(
+ counts[i], counts[i - 1], "Running count should increase
with each batch"
+ )
+
+ def test_apply_in_pandas_iterator_partial_iteration(self):
+ with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch":
2}):
+
+ def func(batches: Iterator[pd.DataFrame]) ->
Iterator[pd.DataFrame]:
+ # Only consume the first batch from the iterator
+ first = next(batches)
+ yield pd.DataFrame({"value": first["id"] % 4})
+
+ df = self.spark.range(20)
+ grouped_df = df.groupBy((col("id") % 4).cast("int"))
+
+ # Should get two records for each group (first batch only)
+ expected = [Row(value=x) for x in [0, 0, 1, 1, 2, 2, 3, 3]]
+
+ actual = grouped_df.applyInPandas(func, "value long").collect()
+ self.assertEqual(actual, expected)
+
class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 23fceb746114..017698f318d5 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -183,8 +183,9 @@ class PandasUDFTestsMixin:
exception=pe.exception,
errorClass="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
messageParameters={
- "eval_type": "SQL_GROUPED_MAP_PANDAS_UDF "
- "or SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
+ "eval_type": "SQL_GROUPED_MAP_PANDAS_UDF or "
+ "SQL_GROUPED_MAP_PANDAS_ITER_UDF or "
+ "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
"return_type": "DoubleType()",
},
)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
index a436b71c123d..09de2a2e3198 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
@@ -26,9 +26,10 @@ from pyspark.testing.sqlutils import (
pandas_requirement_message,
pyarrow_requirement_message,
)
-from pyspark.sql.pandas.typehints import infer_eval_type
+from pyspark.sql.pandas.typehints import infer_eval_type,
infer_group_pandas_eval_type
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
from pyspark.sql import Row
+from pyspark.util import PythonEvalType
if have_pandas:
import pandas as pd
@@ -186,6 +187,49 @@ class PandasUDFTypeHintsTests(ReusedSQLTestCase):
PandasUDFType.GROUPED_AGG,
)
+ def test_type_annotation_group_map(self):
+ # pd.DataFrame -> pd.DataFrame
+ def func(col: pd.DataFrame) -> pd.DataFrame:
+ pass
+
+ self.assertEqual(
+ infer_group_pandas_eval_type(signature(func),
get_type_hints(func)),
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ )
+
+ # Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
+ def func(key: Tuple, col: pd.DataFrame) -> pd.DataFrame:
+ pass
+
+ self.assertEqual(
+ infer_group_pandas_eval_type(signature(func),
get_type_hints(func)),
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ )
+
+ # Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+ def func(col: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+ pass
+
+ self.assertEqual(
+ infer_group_pandas_eval_type(signature(func),
get_type_hints(func)),
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
+ )
+
+ # Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+ def func(key: Tuple, col: Iterator[pd.DataFrame]) ->
Iterator[pd.DataFrame]:
+ pass
+
+ self.assertEqual(
+ infer_group_pandas_eval_type(signature(func),
get_type_hints(func)),
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
+ )
+
+ # Should return None for unsupported signatures
+ def func(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
+ pass
+
+ self.assertEqual(infer_group_pandas_eval_type(signature(func),
get_type_hints(func)), None)
+
def test_type_annotation_negative(self):
def func(col: str) -> pd.Series:
pass
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index b28bccb04bc7..37aa30cc279f 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -238,6 +238,7 @@ class UserDefinedFunction:
)
elif (
evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+ or evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
or evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE
):
if isinstance(returnType, StructType):
@@ -256,6 +257,7 @@ class UserDefinedFunction:
errorClass="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
messageParameters={
"eval_type": "SQL_GROUPED_MAP_PANDAS_UDF or "
+ "SQL_GROUPED_MAP_PANDAS_ITER_UDF or "
"SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
"return_type": str(returnType),
},
@@ -282,7 +284,10 @@ class UserDefinedFunction:
"return_type": str(returnType),
},
)
- elif evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+ elif (
+ evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+ or evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
+ ):
if isinstance(returnType, StructType):
try:
to_arrow_type(returnType)
@@ -298,7 +303,7 @@ class UserDefinedFunction:
raise PySparkTypeError(
errorClass="INVALID_RETURN_TYPE_FOR_ARROW_UDF",
messageParameters={
- "eval_type": "SQL_GROUPED_MAP_ARROW_UDF",
+ "eval_type": "SQL_GROUPED_MAP_ARROW_UDF or
SQL_GROUPED_MAP_ARROW_ITER_UDF",
"return_type": str(returnType),
},
)
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index f94fc73b6435..f633ed699ee2 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -62,6 +62,7 @@ if typing.TYPE_CHECKING:
ArrowGroupedMapUDFType,
ArrowGroupedMapIterUDFType,
ArrowCogroupedMapUDFType,
+ PandasGroupedMapIterUDFType,
PandasGroupedMapUDFTransformWithStateType,
PandasGroupedMapUDFTransformWithStateInitStateType,
GroupedMapUDFTransformWithStateType,
@@ -652,6 +653,7 @@ class PythonEvalType:
214
)
SQL_GROUPED_MAP_ARROW_ITER_UDF: "ArrowGroupedMapIterUDFType" = 215
+ SQL_GROUPED_MAP_PANDAS_ITER_UDF: "PandasGroupedMapIterUDFType" = 216
# Arrow UDFs
SQL_SCALAR_ARROW_UDF: "ArrowScalarUDFType" = 250
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b232b30c5420..09c6a40a33db 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -60,6 +60,7 @@ from pyspark.sql.pandas.serializers import (
CogroupPandasUDFSerializer,
ArrowStreamUDFSerializer,
ApplyInPandasWithStateSerializer,
+ GroupPandasIterUDFSerializer,
GroupPandasUDFSerializer,
TransformWithStateInPandasSerializer,
TransformWithStateInPandasInitStateSerializer,
@@ -748,6 +749,39 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec,
runner_conf):
return lambda k, v: [(wrapped(k, v), arrow_return_type)]
+def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf):
+ _use_large_var_types = use_large_var_types(runner_conf)
+ _assign_cols_by_name = assign_cols_by_name(runner_conf)
+
+ def wrapped(key_series_list, value_series_gen):
+ import pandas as pd
+
+        # value_series_gen is a generator that yields multiple lists of Series (one per batch)
+ # Convert each list of Series into a DataFrame
+ def dataframe_iter():
+ for value_series in value_series_gen:
+ yield pd.concat(value_series, axis=1)
+
+ # Extract key from the first batch
+ if len(argspec.args) == 1:
+ result = f(dataframe_iter())
+ elif len(argspec.args) == 2:
+            # key_series_list is a list of Series for the key columns from the first batch
+ key = tuple(s[0] for s in key_series_list)
+ result = f(key, dataframe_iter())
+
+ def verify_element(df):
+ verify_pandas_result(
+                df, return_type, _assign_cols_by_name, truncate_return_schema=False
+ )
+ return df
+
+ yield from map(verify_element, result)
+
+ arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
+ return lambda k, v: (wrapped(k, v), arrow_return_type)
+
+
def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
result_iter = f(stateful_processor_api_client, mode, key,
value_series_gen)
@@ -1262,6 +1296,11 @@ def read_single_udf(pickleSer, infile, eval_type,
runner_conf, udf_index, profil
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost
when wrapping it
return args_offsets, wrap_grouped_map_pandas_udf(func, return_type,
argspec, runner_conf)
+ elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
+ return args_offsets, wrap_grouped_map_pandas_iter_udf(
+ func, return_type, argspec, runner_conf
+ )
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
argspec = inspect.getfullargspec(chained_func) # signature was lost
when wrapping it
return args_offsets, wrap_grouped_map_arrow_udf(func, return_type,
argspec, runner_conf)
@@ -2565,6 +2604,7 @@ def read_udfs(pickleSer, infile, eval_type):
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
@@ -2635,6 +2675,10 @@ def read_udfs(pickleSer, infile, eval_type):
ser = GroupPandasUDFSerializer(
timezone, safecheck, _assign_cols_by_name,
int_to_decimal_coercion_enabled
)
+ elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+ ser = GroupPandasIterUDFSerializer(
+            timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
+ )
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
ser = CogroupArrowUDFSerializer(_assign_cols_by_name)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
@@ -2894,6 +2938,36 @@ def read_udfs(pickleSer, infile, eval_type):
vals = [a[o] for o in parsed_offsets[0][1]]
return f(keys, vals)
+ elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+ # We assume there is only one UDF here because grouped map doesn't
+ # support combining multiple UDFs.
+ assert num_udfs == 1
+
+ # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+ # distinguish between grouping attributes and data attributes
+ arg_offsets, f = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
+ )
+ parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+ # Create mapper similar to Arrow iterator:
+        # `a` is an iterator of Series lists (one list per batch, containing all columns)
+        # Materialize first batch to get keys, then create generator for value batches
+ def mapper(a):
+ import itertools
+
+ series_iter = iter(a)
+ # Need to materialize the first series list to get the keys
+ first_series_list = next(series_iter)
+
+ keys = [first_series_list[o] for o in parsed_offsets[0][0]]
+ value_series_gen = (
+ [series_list[o] for o in parsed_offsets[0][1]]
+                for series_list in itertools.chain((first_series_list,), series_iter)
+ )
+
+ return f(keys, value_series_gen)
+
elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
# We assume there is only one UDF here because grouped map doesn't
# support combining multiple UDFs.
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 41efe8db842f..8f8e6261066f 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -690,7 +690,8 @@ class SparkConnectPlanner(
.groupBy(cols: _*)
pythonUdf.evalType match {
- case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF =>
+ case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF |
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF =>
group.flatMapGroupsInPandas(Column(pythonUdf)).logicalPlan
case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF |
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
index 674c206c96a9..bd7b3348b9f0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
@@ -278,7 +278,8 @@ class RelationalGroupedDataset protected[sql](
*/
private[sql] def flatMapGroupsInPandas(column: Column): DataFrame = {
val expr = column.expr.asInstanceOf[PythonUDF]
- require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF ||
+ expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
"Must pass a grouped map pandas udf")
require(expr.dataType.isInstanceOf[StructType],
s"The returnType of the udf must be a ${StructType.simpleString}")
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 82c2f3fab200..765181d7e331 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.python
-import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
@@ -46,7 +45,7 @@ case class FlatMapGroupsInPandasExec(
child: SparkPlan)
extends FlatMapGroupsInBatchExec {
- protected val pythonEvalType: Int = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+ protected val pythonEvalType: Int = func.asInstanceOf[PythonUDF].evalType
override protected def withNewChildInternal(newChild: SparkPlan):
FlatMapGroupsInPandasExec =
copy(child = newChild)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]