This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4202f239c45a [SPARK-53614][PYTHON] Add `Iterator[pandas.DataFrame]` support to `applyInPandas`
4202f239c45a is described below

commit 4202f239c45a290e340bcd505de849d876f992fa
Author: Yicong-Huang <[email protected]>
AuthorDate: Fri Oct 31 14:33:05 2025 +0800

    [SPARK-53614][PYTHON] Add `Iterator[pandas.DataFrame]` support to `applyInPandas`
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for the `Iterator[pandas.DataFrame]` API in `groupBy().applyInPandas()`, enabling batch-by-batch processing of grouped data for improved memory efficiency and scalability.
    
    #### Key Changes:
    
    1. **New PythonEvalType**: Added `SQL_GROUPED_MAP_PANDAS_ITER_UDF` to 
distinguish iterator-based UDFs from standard grouped map UDFs
    
    2. **Type Inference**: Implemented automatic detection of iterator 
signatures:
       - `Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]`
       - `Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]`
    
    3. **Streaming Serialization**: Created `GroupPandasIterUDFSerializer` that 
streams results without materializing all DataFrames in memory
    
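    4. **Eval Type Propagation Fix**: Updated `FlatMapGroupsInPandasExec`, which hardcoded `pythonEvalType = 201`, to take the eval type from the UDF expression instead (mirroring the existing fix in `FlatMapGroupsInArrowExec`). Otherwise, the physical plan would always run grouped map UDFs as eval type 201 and ignore `SQL_GROUPED_MAP_PANDAS_ITER_UDF`.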
    
    ### Why are the changes needed?
    
    The existing `applyInPandas()` API loads entire groups into memory as 
single DataFrames. For large groups, this can cause OOM errors. The iterator 
API allows:
    
    - **Memory Efficiency**: Process data batch-by-batch instead of 
materializing entire groups
    - **Scalability**: Handle arbitrarily large groups that don't fit in memory
    - **Consistency**: Mirrors the existing `applyInArrow()` iterator API design
    
    ### Does this PR introduce any user-facing changes?
    
    Yes, this PR adds a new API variant for `applyInPandas()`:
    
    #### Before (existing API, still supported):
    ```python
    def normalize(pdf: pd.DataFrame) -> pd.DataFrame:
        return pdf.assign(v=(pdf.v - pdf.v.mean()) / pdf.v.std())
    
    df.groupBy("id").applyInPandas(normalize, schema="id long, v double")
    ```
    
    #### After (new iterator API):
    ```python
    from typing import Iterator
    
    def normalize(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        # Process data batch-by-batch
        for batch in batches:
            yield batch.assign(v=(batch.v - batch.v.mean()) / batch.v.std())
    
    df.groupBy("id").applyInPandas(normalize, schema="id long, v double")
    ```
    
    #### With Grouping Keys:
    ```python
    from typing import Iterator, Tuple, Any
    
    def sum_by_key(key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        total = 0
        for batch in batches:
            total += batch['v'].sum()
        yield pd.DataFrame({"id": [key[0]], "total": [total]})
    
    df.groupBy("id").applyInPandas(sum_by_key, schema="id long, total double")
    ```
    
    **Backward Compatibility**: The existing DataFrame-to-DataFrame API is 
fully preserved and continues to work without changes.
    
    ### How was this patch tested?
    
    - Added `test_apply_in_pandas_iterator_basic` - Basic functionality test
    - Added `test_apply_in_pandas_iterator_with_keys` - Test with grouping keys
    - Added `test_apply_in_pandas_iterator_batch_slicing` - Pressure test with 
10M rows, 20 columns
    - Added `test_apply_in_pandas_iterator_with_keys_batch_slicing` - Pressure 
test with keys
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, tests generated by Cursor.
    
    Closes #52716 from Yicong-Huang/SPARK-53614/feat/add-apply-in-pandas.
    
    Lead-authored-by: Yicong-Huang <[email protected]>
    Co-authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../org/apache/spark/api/python/PythonRunner.scala |   3 +
 python/pyspark/sql/connect/group.py                |  15 +-
 python/pyspark/sql/pandas/_typing/__init__.pyi     |   3 +
 python/pyspark/sql/pandas/functions.py             |  19 +
 python/pyspark/sql/pandas/functions.pyi            |  14 +-
 python/pyspark/sql/pandas/group_ops.py             |  96 ++++-
 python/pyspark/sql/pandas/serializers.py           |  82 +++++
 python/pyspark/sql/pandas/typehints.py             |  95 +++++
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 395 ++++++++++++++++++++-
 python/pyspark/sql/tests/pandas/test_pandas_udf.py |   5 +-
 .../sql/tests/pandas/test_pandas_udf_typehints.py  |  46 ++-
 python/pyspark/sql/udf.py                          |   9 +-
 python/pyspark/util.py                             |   2 +
 python/pyspark/worker.py                           |  74 ++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |   3 +-
 .../sql/classic/RelationalGroupedDataset.scala     |   3 +-
 .../python/FlatMapGroupsInPandasExec.scala         |   3 +-
 17 files changed, 837 insertions(+), 30 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index b3208980da24..66e204fee44b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -67,6 +67,7 @@ private[spark] object PythonEvalType {
   val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF = 213
   val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214
   val SQL_GROUPED_MAP_ARROW_ITER_UDF = 215
+  val SQL_GROUPED_MAP_PANDAS_ITER_UDF = 216
 
   // Arrow UDFs
   val SQL_SCALAR_ARROW_UDF = 250
@@ -102,6 +103,8 @@ private[spark] object PythonEvalType {
     case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF => 
"SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF"
     case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
       "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF"
+    case SQL_GROUPED_MAP_ARROW_ITER_UDF => "SQL_GROUPED_MAP_ARROW_ITER_UDF"
+    case SQL_GROUPED_MAP_PANDAS_ITER_UDF => "SQL_GROUPED_MAP_PANDAS_ITER_UDF"
 
     // Arrow UDFs
     case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
diff --git a/python/pyspark/sql/connect/group.py 
b/python/pyspark/sql/connect/group.py
index 4bba2b1e0b71..52d280c2c264 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -294,14 +294,25 @@ class GroupedData:
     ) -> "DataFrame":
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
+        from pyspark.sql.pandas.typehints import 
infer_group_pandas_eval_type_from_func
 
-        _validate_vectorized_udf(func, 
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
+        # Try to infer the eval type from type hints
+        eval_type = None
+        try:
+            eval_type = infer_group_pandas_eval_type_from_func(func)
+        except Exception:
+            warnings.warn("Cannot infer the eval type from type hints.", 
UserWarning)
+
+        if eval_type is None:
+            eval_type = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+
+        _validate_vectorized_udf(func, eval_type)
         if isinstance(schema, str):
             schema = cast(StructType, self._df._session._parse_ddl(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
-            evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+            evalType=eval_type,
         )
 
         res = DataFrame(
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi 
b/python/pyspark/sql/pandas/_typing/__init__.pyi
index cea44921069f..4841b50544fd 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -61,6 +61,7 @@ PandasGroupedMapUDFTransformWithStateInitStateType = 
Literal[212]
 GroupedMapUDFTransformWithStateType = Literal[213]
 GroupedMapUDFTransformWithStateInitStateType = Literal[214]
 ArrowGroupedMapIterUDFType = Literal[215]
+PandasGroupedMapIterUDFType = Literal[216]
 
 # Arrow UDFs
 ArrowScalarUDFType = Literal[250]
@@ -347,6 +348,8 @@ PandasScalarIterFunction = Union[
 PandasGroupedMapFunction = Union[
     Callable[[DataFrameLike], DataFrameLike],
     Callable[[Any, DataFrameLike], DataFrameLike],
+    Callable[[Iterator[DataFrameLike]], Iterator[DataFrameLike]],
+    Callable[[Any, Iterator[DataFrameLike]], Iterator[DataFrameLike]],
 ]
 
 PandasGroupedMapFunctionWithState = Callable[
diff --git a/python/pyspark/sql/pandas/functions.py 
b/python/pyspark/sql/pandas/functions.py
index 036905766a56..f115a91fce09 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -322,6 +322,7 @@ def arrow_udf(f=None, returnType=None, functionType=None):
     pyspark.sql.GroupedData.applyInArrow
     pyspark.sql.PandasCogroupedOps.applyInArrow
     pyspark.sql.UDFRegistration.register
+    pyspark.sql.GroupedData.applyInPandas
     """
     require_minimum_pyarrow_version()
 
@@ -346,6 +347,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     .. versionchanged:: 4.0.0
         Supports keyword-arguments in SCALAR and GROUPED_AGG type.
 
+    .. versionchanged:: 4.1.0
+        Supports iterator API in GROUPED_MAP type.
+
     Parameters
     ----------
     f : function, optional
@@ -690,6 +694,7 @@ def vectorized_udf(
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
@@ -771,6 +776,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = 
"pandas") -> int:
         )
     elif evalType in [
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
@@ -836,6 +842,19 @@ def _validate_vectorized_udf(f, evalType, kind: str = 
"pandas") -> int:
             },
         )
 
+    if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF and 
len(argspec.args) not in (
+        1,
+        2,
+    ):
+        raise PySparkValueError(
+            errorClass="INVALID_PANDAS_UDF",
+            messageParameters={
+                "detail": "the function in groupby.applyInPandas with iterator 
API must take "
+                "either one argument (batches: Iterator[pandas.DataFrame]) or 
two arguments "
+                "(key, batches: Iterator[pandas.DataFrame]).",
+            },
+        )
+
     if evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF and 
len(argspec.args) not in (1, 2):
         raise PySparkValueError(
             errorClass="INVALID_PANDAS_UDF",
diff --git a/python/pyspark/sql/pandas/functions.pyi 
b/python/pyspark/sql/pandas/functions.pyi
index 70ff08679b6b..b9417583b7d4 100644
--- a/python/pyspark/sql/pandas/functions.pyi
+++ b/python/pyspark/sql/pandas/functions.pyi
@@ -29,6 +29,7 @@ from pyspark.sql.pandas._typing import (
     PandasGroupedAggFunction,
     PandasGroupedAggUDFType,
     PandasGroupedMapFunction,
+    PandasGroupedMapIterUDFType,
     PandasGroupedMapUDFType,
     PandasScalarIterFunction,
     PandasScalarIterUDFType,
@@ -145,19 +146,24 @@ def pandas_udf(
 def pandas_udf(
     f: PandasGroupedMapFunction,
     returnType: Union[StructType, str],
-    functionType: PandasGroupedMapUDFType,
+    functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
 ) -> GroupedMapPandasUserDefinedFunction: ...
 @overload
 def pandas_udf(
-    f: Union[StructType, str], returnType: PandasGroupedMapUDFType
+    f: Union[StructType, str],
+    returnType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
 ) -> Callable[[PandasGroupedMapFunction], 
GroupedMapPandasUserDefinedFunction]: ...
 @overload
 def pandas_udf(
-    *, returnType: Union[StructType, str], functionType: 
PandasGroupedMapUDFType
+    *,
+    returnType: Union[StructType, str],
+    functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
 ) -> Callable[[PandasGroupedMapFunction], 
GroupedMapPandasUserDefinedFunction]: ...
 @overload
 def pandas_udf(
-    f: Union[StructType, str], *, functionType: PandasGroupedMapUDFType
+    f: Union[StructType, str],
+    *,
+    functionType: Union[PandasGroupedMapUDFType, PandasGroupedMapIterUDFType],
 ) -> Callable[[PandasGroupedMapFunction], 
GroupedMapPandasUserDefinedFunction]: ...
 @overload
 def pandas_udf(
diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index 07d78a6ce6d8..1b4aa8798727 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -123,12 +123,13 @@ class PandasGroupedOpsMixin:
         Maps each group of the current :class:`DataFrame` using a pandas udf 
and returns the result
         as a `DataFrame`.
 
-        The function should take a `pandas.DataFrame` and return another
-        `pandas.DataFrame`. Alternatively, the user can pass a function that 
takes
-        a tuple of the grouping key(s) and a `pandas.DataFrame`.
-        For each group, all columns are passed together as a `pandas.DataFrame`
-        to the user-function and the returned `pandas.DataFrame` are combined 
as a
-        :class:`DataFrame`.
+        The function can take one of two forms: It can take a 
`pandas.DataFrame` and return a
+        `pandas.DataFrame`, or it can take an iterator of `pandas.DataFrame` 
and yield
+        `pandas.DataFrame`. Alternatively, each form can take a tuple of grouping keys
+        as the first argument in addition to the input type above.
+        For each group, all columns are passed together as a 
`pandas.DataFrame` or iterator of
+        `pandas.DataFrame`, and the returned `pandas.DataFrame` or iterator of 
`pandas.DataFrame`
+        are combined as a :class:`DataFrame`.
 
         The `schema` should be a :class:`StructType` describing the schema of 
the returned
         `pandas.DataFrame`. The column labels of the returned 
`pandas.DataFrame` must either match
@@ -141,12 +142,17 @@ class PandasGroupedOpsMixin:
         .. versionchanged:: 3.4.0
             Support Spark Connect.
 
+        .. versionchanged:: 4.1.0
+            Added support for an iterator of `pandas.DataFrame` API.
+
         Parameters
         ----------
         func : function
-            a Python native function that takes a `pandas.DataFrame` and 
outputs a
-            `pandas.DataFrame`, or that takes one tuple (grouping keys) and a
-            `pandas.DataFrame` and outputs a `pandas.DataFrame`.
+            a Python native function that either takes a `pandas.DataFrame` 
and outputs a
+            `pandas.DataFrame` or takes an iterator of `pandas.DataFrame` and 
yields
+            `pandas.DataFrame`. Additionally, each form can take a tuple of 
grouping keys
+            as the first argument, with the `pandas.DataFrame` or iterator of 
`pandas.DataFrame`
+            as the second argument.
         schema : :class:`pyspark.sql.types.DataType` or str
             the return type of the `func` in PySpark. The value can be either a
             :class:`pyspark.sql.types.DataType` object or a DDL-formatted type 
string.
@@ -214,22 +220,84 @@ class PandasGroupedOpsMixin:
         |  2|          2| 3.0|
         +---+-----------+----+
 
+        The function can also take and return an iterator of 
`pandas.DataFrame` using type
+        hints.
+
+        >>> from typing import Iterator  # doctest: +SKIP
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))  # doctest: +SKIP
+        >>> def filter_func(
+        ...     batches: Iterator[pd.DataFrame]
+        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ...     for batch in batches:
+        ...         # Process and yield each batch independently
+        ...         filtered = batch[batch['v'] > 2.0]
+        ...         if not filtered.empty:
+        ...             yield filtered[['v']]
+        >>> df.groupby("id").applyInPandas(
+        ...     filter_func, schema="v double").show()  # doctest: +SKIP
+        +----+
+        |   v|
+        +----+
+        | 3.0|
+        | 5.0|
+        |10.0|
+        +----+
+
+        Alternatively, the user can pass a function that takes two arguments.
+        In this case, the grouping key(s) will be passed as the first argument 
and the data will
+        be passed as the second argument. The grouping key(s) will be passed 
as a tuple of numpy
+        data types. The data will still be passed in as an iterator of 
`pandas.DataFrame`.
+
+        >>> from typing import Iterator, Tuple, Any  # doctest: +SKIP
+        >>> def transform_func(
+        ...     key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ...     for batch in batches:
+        ...         # Yield transformed results for each batch
+        ...         result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
+        ...         yield result[['id', 'v_doubled']]
+        >>> df.groupby("id").applyInPandas(
+        ...     transform_func, schema="id long, v_doubled double").show()  # 
doctest: +SKIP
+        +---+---------+
+        | id|v_doubled|
+        +---+---------+
+        |  1|      2.0|
+        |  1|      4.0|
+        |  2|      6.0|
+        |  2|     10.0|
+        |  2|     20.0|
+        +---+---------+
+
         Notes
         -----
-        This function requires a full shuffle. All the data of a group will be 
loaded
-        into memory, so the user should be aware of the potential OOM risk if 
data is skewed
-        and certain groups are too large to fit in memory.
+        This function requires a full shuffle. With the `pandas.DataFrame` API, all data of a
+        group is loaded into memory, so the user should be aware of the potential OOM risk if
+        data is skewed and certain groups are too large to fit in memory. Using the iterator of
+        `pandas.DataFrame` API instead mitigates this, since a group's batches are processed one at a time.
 
         See Also
         --------
         pyspark.sql.functions.pandas_udf
         """
         from pyspark.sql import GroupedData
-        from pyspark.sql.functions import pandas_udf, PandasUDFType
+        from pyspark.sql.functions import pandas_udf
+        from pyspark.sql.pandas.typehints import 
infer_group_pandas_eval_type_from_func
 
         assert isinstance(self, GroupedData)
 
-        udf = pandas_udf(func, returnType=schema, 
functionType=PandasUDFType.GROUPED_MAP)
+        # Try to infer the eval type from type hints
+        eval_type = None
+        try:
+            eval_type = infer_group_pandas_eval_type_from_func(func)
+        except Exception as e:
+            warnings.warn(f"Cannot infer the eval type from type hints: {e}", 
UserWarning)
+
+        if eval_type is None:
+            eval_type = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+
+        udf = pandas_udf(func, returnType=schema, functionType=eval_type)
         df = self._df
         udf_column = udf(*[df[col] for col in df.columns])
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc)
diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index ef8b8c0e6421..56d338fc1371 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1251,6 +1251,88 @@ class 
GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
         return "GroupPandasUDFSerializer"
 
 
+class GroupPandasIterUDFSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer for grouped map Pandas iterator UDFs.
+
+    Loads grouped data as pandas.Series and serializes results from iterator 
UDFs.
+    Flattens the (dataframes_generator, arrow_type) tuple by iterating over 
the generator.
+    """
+
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        int_to_decimal_coercion_enabled,
+    ):
+        super(GroupPandasIterUDFSerializer, self).__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=assign_cols_by_name,
+            df_for_struct=False,
+            struct_in_pandas="dict",
+            ndarray_as_list=False,
+            arrow_cast=True,
+            input_types=None,
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+        )
+
+    def load_stream(self, stream):
+        """
+        Deserialize grouped ArrowRecordBatches and, for each group, yield a generator of
+        pandas.Series lists (one list per Arrow batch) so the UDF can consume it batch-by-batch.
+        """
+        import pyarrow as pa
+
+        def process_group(batches: "Iterator[pa.RecordBatch]"):
+            # Convert each Arrow batch to pandas Series list on-demand, 
yielding one list per batch
+            for batch in batches:
+                series = [
+                    self.arrow_to_pandas(c, i)
+                    for i, c in 
enumerate(pa.Table.from_batches([batch]).itercolumns())
+                ]
+                yield series
+
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                # Lazily read and convert Arrow batches one at a time from the 
stream
+                # This avoids loading all batches into memory for the group
+                batch_iter = 
process_group(ArrowStreamSerializer.load_stream(self, stream))
+                yield batch_iter
+                # Make sure the batches are fully iterated before getting the 
next group
+                for _ in batch_iter:
+                    pass
+
+            elif dataframes_in_group != 0:
+                raise PySparkValueError(
+                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+                    messageParameters={"dataframes_in_group": 
str(dataframes_in_group)},
+                )
+
+    def dump_stream(self, iterator, stream):
+        """
+        Flatten the (dataframes_generator, arrow_type) tuples by iterating 
over each generator.
+        This allows the iterator UDF to stream results without materializing 
all DataFrames.
+        """
+        # Flatten: (dataframes_generator, arrow_type) -> (df, arrow_type), 
(df, arrow_type), ...
+        flattened_iter = (
+            (df, arrow_type) for dataframes_gen, arrow_type in iterator for df 
in dataframes_gen
+        )
+
+        # Convert each (df, arrow_type) to the format expected by parent's 
dump_stream
+        series_iter = ([(df, arrow_type)] for df, arrow_type in flattened_iter)
+
+        super(GroupPandasIterUDFSerializer, self).dump_stream(series_iter, 
stream)
+
+    def __repr__(self):
+        return "GroupPandasIterUDFSerializer"
+
+
 class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
     """
     Serializes pyarrow.RecordBatch data with Arrow streaming format.
diff --git a/python/pyspark/sql/pandas/typehints.py 
b/python/pyspark/sql/pandas/typehints.py
index c184e0dc5668..18858ab0cf68 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -32,6 +32,9 @@ if TYPE_CHECKING:
         ArrowGroupedMapIterUDFType,
         ArrowGroupedMapUDFType,
         ArrowGroupedMapFunction,
+        PandasGroupedMapFunction,
+        PandasGroupedMapUDFType,
+        PandasGroupedMapIterUDFType,
     )
 
 
@@ -394,6 +397,98 @@ def infer_group_arrow_eval_type_from_func(
         return None
 
 
+def infer_group_pandas_eval_type(
+    sig: Signature,
+    type_hints: Dict[str, Any],
+) -> Optional[Union["PandasGroupedMapUDFType", "PandasGroupedMapIterUDFType"]]:
+    from pyspark.sql.pandas.functions import PythonEvalType
+
+    require_minimum_pandas_version()
+
+    import pandas as pd
+
+    annotations = {}
+    for param in sig.parameters.values():
+        if param.annotation is not param.empty:
+            annotations[param.name] = type_hints.get(param.name, 
param.annotation)
+
+    # Check if all arguments have type hints
+    parameters_sig = [
+        annotations[parameter] for parameter in sig.parameters if parameter in 
annotations
+    ]
+    if len(parameters_sig) != len(sig.parameters):
+        raise PySparkValueError(
+            errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+            messageParameters={"target": "all parameters", "sig": str(sig)},
+        )
+
+    # Check if the return has a type hint
+    return_annotation = type_hints.get("return", sig.return_annotation)
+    if sig.empty is return_annotation:
+        raise PySparkValueError(
+            errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+            messageParameters={"target": "the return type", "sig": str(sig)},
+        )
+
+    # Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+    is_iterator_dataframe = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[0],
+            parameter_check_func=lambda t: t == pd.DataFrame,
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
+        )
+    )
+    # Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+    is_iterator_dataframe_with_keys = (
+        len(parameters_sig) == 2
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[1],
+            parameter_check_func=lambda t: t == pd.DataFrame,
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
+        )
+    )
+
+    if is_iterator_dataframe or is_iterator_dataframe_with_keys:
+        return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
+
+    # pd.DataFrame -> pd.DataFrame
+    is_dataframe = (
+        len(parameters_sig) == 1
+        and parameters_sig[0] == pd.DataFrame
+        and return_annotation == pd.DataFrame
+    )
+    # Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
+    is_dataframe_with_keys = (
+        len(parameters_sig) == 2
+        and parameters_sig[1] == pd.DataFrame
+        and return_annotation == pd.DataFrame
+    )
+    if is_dataframe or is_dataframe_with_keys:
+        return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+
+    return None
+
+
+def infer_group_pandas_eval_type_from_func(
+    f: "PandasGroupedMapFunction",
+) -> Optional[Union["PandasGroupedMapUDFType", "PandasGroupedMapIterUDFType"]]:
+    argspec = getfullargspec(f)
+    if len(argspec.annotations) > 0:
+        try:
+            type_hints = get_type_hints(f)
+        except NameError:
+            type_hints = {}
+
+        return infer_group_pandas_eval_type(signature(f), type_hints)
+    else:
+        return None
+
+
 def check_tuple_annotation(
     annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = 
None
 ) -> bool:
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 0e922d072871..991fe67ea1f5 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -21,7 +21,7 @@ import logging
 
 from collections import OrderedDict
 from decimal import Decimal
-from typing import cast
+from typing import cast, Iterator, Tuple, Any
 
 from pyspark.sql import Row, functions as sf
 from pyspark.sql.functions import (
@@ -1024,6 +1024,399 @@ class ApplyInPandasTestsMixin:
             ],
         )
 
+    def test_apply_in_pandas_iterator_basic(self):
+        """An Iterator[pd.DataFrame] -> Iterator[pd.DataFrame] UDF can aggregate a group."""
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        def sum_func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Accumulate over every batch of the group, then emit a single summary row.
+            total = 0
+            for batch in batches:
+                total += batch["v"].sum()
+            yield pd.DataFrame({"v": [total]})
+
+        result = df.groupby("id").applyInPandas(sum_func, schema="v double").orderBy("v").collect()
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0][0], 3.0)
+        self.assertEqual(result[1][0], 18.0)
+
+    def test_apply_in_pandas_iterator_with_keys(self):
+        """The (key, batches) iterator form receives the grouping key as a tuple."""
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        def sum_func(
+            key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ) -> Iterator[pd.DataFrame]:
+            acc = 0
+            for pdf in batches:
+                acc += pdf["v"].sum()
+            # key[0] is the single grouping column ("id").
+            yield pd.DataFrame({"id": [key[0]], "v": [acc]})
+
+        rows = (
+            df.groupby("id")
+            .applyInPandas(sum_func, schema="id long, v double")
+            .orderBy("id")
+            .collect()
+        )
+        # One (id, sum) row per group, in id order.
+        self.assertEqual([tuple(row) for row in rows], [(1, 3.0), (2, 18.0)])
+
+    def test_apply_in_pandas_iterator_batch_slicing(self):
+        """Every batch of a large group reaches the iterator UDF under various batch limits."""
+        df = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+        # Extra columns widen each row (presumably so the byte-based limit also matters).
+        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+        df = df.withColumns(cols)
+
+        def min_max_v(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Collect all batches to compute min/max across the entire group
+            all_data = []
+            key_val = None
+            for batch in batches:
+                all_data.append(batch)
+                if key_val is None:
+                    key_val = batch.key.iloc[0]
+
+            combined = pd.concat(all_data, ignore_index=True)
+            # Each of the two keys owns exactly half of the rows; fewer means a lost batch.
+            assert len(combined) == 10000000 / 2, len(combined)
+
+            yield pd.DataFrame(
+                {
+                    "key": [key_val],
+                    "min": [combined.v.min()],
+                    "max": [combined.v.max()],
+                }
+            )
+
+        expected = (
+            df.groupby("key")
+            .agg(
+                sf.min("v").alias("min"),
+                sf.max("v").alias("max"),
+            )
+            .sort("key")
+        ).collect()
+
+        # Record-limited, byte-limited (maxRecordsPerBatch=0 means no record limit), and both.
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df.groupBy("key")
+                        .applyInPandas(min_max_v, "key long, min long, max long")
+                        .sort("key")
+                    ).collect()
+
+                    self.assertEqual(expected, result)
+
+    def test_apply_in_pandas_iterator_with_keys_batch_slicing(self):
+        """Keyed variant of the batch-slicing test: the whole group arrives in slices."""
+        df = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+        # Extra columns widen each row (presumably so the byte-based limit also matters).
+        cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
+        df = df.withColumns(cols)
+
+        def min_max_v(
+            key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ) -> Iterator[pd.DataFrame]:
+            # Collect all batches to compute min/max across the entire group
+            all_data = []
+            for batch in batches:
+                all_data.append(batch)
+
+            combined = pd.concat(all_data, ignore_index=True)
+            # Each of the two keys owns exactly half of the rows; fewer means a lost batch.
+            assert len(combined) == 10000000 / 2, len(combined)
+
+            yield pd.DataFrame(
+                {
+                    "key": [key[0]],
+                    "min": [combined.v.min()],
+                    "max": [combined.v.max()],
+                }
+            )
+
+        expected = (
+            df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
+        ).collect()
+
+        # Record-limited, byte-limited (maxRecordsPerBatch=0 means no record limit), and both.
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df.groupBy("key")
+                        .applyInPandas(min_max_v, "key long, min long, max long")
+                        .sort("key")
+                    ).collect()
+
+                    self.assertEqual(expected, result)
+
+    def test_apply_in_pandas_iterator_multiple_output_batches(self):
+        """A UDF may yield many output frames for one group (here one frame per row)."""
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (1, 3.0), (2, 4.0), (2, 5.0), (2, 6.0)], ("id", "v")
+        )
+
+        def split_and_yield(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Yield multiple output batches for each input batch
+            for batch in batches:
+                for _, row in batch.iterrows():
+                    # Yield each row as a separate batch to test multiple yields
+                    yield pd.DataFrame(
+                        {"id": [row["id"]], "v": [row["v"]], "v_doubled": [row["v"] * 2]}
+                    )
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(split_and_yield, schema="id long, v double, v_doubled double")
+            .orderBy("id", "v")
+            .collect()
+        )
+
+        expected = [
+            Row(id=1, v=1.0, v_doubled=2.0),
+            Row(id=1, v=2.0, v_doubled=4.0),
+            Row(id=1, v=3.0, v_doubled=6.0),
+            Row(id=2, v=4.0, v_doubled=8.0),
+            Row(id=2, v=5.0, v_doubled=10.0),
+            Row(id=2, v=6.0, v_doubled=12.0),
+        ]
+        self.assertEqual(result, expected)
+
+    def test_apply_in_pandas_iterator_filter_multiple_batches(self):
+        """Several filtered frames yielded per input batch together cover the whole group."""
+        df = self.spark.createDataFrame(
+            [(1, i * 1.0) for i in range(20)] + [(2, i * 1.0) for i in range(20)], ("id", "v")
+        )
+
+        def filter_and_yield(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Yield filtered results from each batch
+            for batch in batches:
+                # Filter even values and yield
+                even_batch = batch[batch["v"] % 2 == 0]
+                if not even_batch.empty:
+                    yield even_batch
+
+                # Filter odd values and yield separately
+                odd_batch = batch[batch["v"] % 2 == 1]
+                if not odd_batch.empty:
+                    yield odd_batch
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(filter_and_yield, schema="id long, v double")
+            .orderBy("id", "v")
+            .collect()
+        )
+
+        # Verify all 40 rows are present (20 per group)
+        self.assertEqual(len(result), 40)
+
+        # Verify group 1 has all values 0-19
+        group1 = [row for row in result if row[0] == 1]
+        self.assertEqual(len(group1), 20)
+        self.assertEqual([row[1] for row in group1], [float(i) for i in range(20)])
+
+        # Verify group 2 has all values 0-19
+        group2 = [row for row in result if row[0] == 2]
+        self.assertEqual(len(group2), 20)
+        self.assertEqual([row[1] for row in group2], [float(i) for i in range(20)])
+
+    def test_apply_in_pandas_iterator_with_keys_multiple_batches(self):
+        """A keyed iterator UDF may yield several output frames per input batch."""
+        df = self.spark.createDataFrame(
+            [
+                (1, "a", 1.0),
+                (1, "b", 2.0),
+                (1, "c", 3.0),
+                (2, "d", 4.0),
+                (2, "e", 5.0),
+                (2, "f", 6.0),
+            ],
+            ("id", "name", "v"),
+        )
+
+        def process_with_key(
+            key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
+        ) -> Iterator[pd.DataFrame]:
+            # Yield multiple processed batches, including the key in each output
+            for batch in batches:
+                # Emit the batch twice: in 1-row chunks and in 2-row chunks. Either pass
+                # covers every row once, so each input row appears exactly twice in total.
+                for chunk_size in [1, 2]:
+                    for i in range(0, len(batch), chunk_size):
+                        chunk = batch.iloc[i : i + chunk_size]
+                        if not chunk.empty:
+                            out = chunk.assign(id=key[0], total=chunk["v"].sum())
+                            yield out[["id", "name", "total"]]
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(process_with_key, schema="id long, name string, total double")
+            .orderBy("id", "name")
+            .collect()
+        )
+
+        # 6 input rows, each emitted exactly twice, regardless of input batch boundaries.
+        self.assertEqual(len(result), 12)
+
+        # Ordered by (id, name), so every name appears twice, in order.
+        names = [row[1] for row in result]
+        self.assertEqual(names, [name for name in "abcdef" for _ in range(2)])
+
+        # Verify keys are correct
+        for row in result:
+            if row[1] in ["a", "b", "c"]:
+                self.assertEqual(row[0], 1)
+            else:
+                self.assertEqual(row[0], 2)
+
+    def test_apply_in_pandas_iterator_process_multiple_input_batches(self):
+        """A group split into several input batches can be processed batch by batch."""
+        import builtins
+
+        # Create large dataset to trigger batch slicing
+        df = self.spark.range(100000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+
+        def process_batches_progressively(
+            batches: Iterator[pd.DataFrame],
+        ) -> Iterator[pd.DataFrame]:
+            # Process each input batch and yield output immediately
+            batch_count = 0
+            for batch in batches:
+                batch_count += 1
+                # Yield a summary for each input batch processed
+                yield pd.DataFrame(
+                    {
+                        "key": [batch.key.iloc[0]],
+                        "batch_num": [batch_count],
+                        "count": [len(batch)],
+                        "sum": [batch.v.sum()],
+                    }
+                )
+
+        # Use small batch size to force multiple input batches
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": 10000,
+            }
+        ):
+            result = (
+                df.groupBy("key")
+                .applyInPandas(
+                    process_batches_progressively,
+                    schema="key long, batch_num long, count long, sum long",
+                )
+                .orderBy("key", "batch_num")
+                .collect()
+            )
+
+        # Verify we got multiple batches per group (100000/2 = 50000 rows per group)
+        # With maxRecordsPerBatch=10000, should get 5 batches per group
+        group_0_batches = [r for r in result if r[0] == 0]
+        group_1_batches = [r for r in result if r[0] == 1]
+
+        # Verify multiple batches were processed
+        self.assertGreater(len(group_0_batches), 1)
+        self.assertGreater(len(group_1_batches), 1)
+
+        # Verify the sum across all batches equals expected total (using Python's built-in sum)
+        group_0_sum = builtins.sum(r[3] for r in group_0_batches)
+        group_1_sum = builtins.sum(r[3] for r in group_1_batches)
+
+        # Expected: sum of even numbers 0,2,4,...,99998 and of odd numbers 1,3,...,99999
+        expected_even_sum = builtins.sum(range(0, 100000, 2))
+        expected_odd_sum = builtins.sum(range(1, 100000, 2))
+
+        self.assertEqual(group_0_sum, expected_even_sum)
+        self.assertEqual(group_1_sum, expected_odd_sum)
+
+    def test_apply_in_pandas_iterator_streaming_aggregation(self):
+        """A UDF can keep running state across batches and emit one row per input batch."""
+        # Create dataset with multiple batches per group
+        df = self.spark.range(50000).select(
+            (sf.col("id") % 3).alias("key"),
+            (sf.col("id") % 100).alias("category"),
+            sf.col("id").alias("value"),
+        )
+
+        def streaming_aggregate(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            # Maintain running aggregates and yield intermediate results
+            running_sum = 0
+            running_count = 0
+
+            for batch in batches:
+                # Update running aggregates
+                running_sum += batch.value.sum()
+                running_count += len(batch)
+
+                # Yield current stats after processing each batch
+                yield pd.DataFrame(
+                    {
+                        "key": [batch.key.iloc[0]],
+                        "running_count": [running_count],
+                        "running_avg": [running_sum / running_count],
+                    }
+                )
+
+        # Force multiple batches with small batch size
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.maxRecordsPerBatch": 5000,
+            }
+        ):
+            result = (
+                df.groupBy("key")
+                .applyInPandas(
+                    streaming_aggregate, schema="key long, running_count long, running_avg double"
+                )
+                .collect()
+            )
+
+        # Verify we got multiple rows per group (one per input batch)
+        for key_val in [0, 1, 2]:
+            key_results = [r for r in result if r[0] == key_val]
+            # Should have multiple batches
+            # (50000/3 ≈ 16667 rows per group, with 5000 per batch = ~4 batches)
+            self.assertGreater(len(key_results), 1, f"Expected multiple batches for key {key_val}")
+
+            # Verify running_count increases monotonically
+            # (relies on collect() keeping one group's rows in the order the UDF yielded them)
+            counts = [r[1] for r in key_results]
+            for i in range(1, len(counts)):
+                self.assertGreater(
+                    counts[i], counts[i - 1], "Running count should increase with each batch"
+                )
+
+    def test_apply_in_pandas_iterator_partial_iteration(self):
+        """A UDF may stop consuming its input iterator early (only the first batch here)."""
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
+
+            def func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+                # Only consume the first batch from the iterator
+                first = next(batches)
+                yield pd.DataFrame({"value": first["id"] % 4})
+
+            df = self.spark.range(20)
+            grouped_df = df.groupBy((sf.col("id") % 4).cast("int"))
+
+            # Each of the 4 groups has 5 rows; with 2 records per batch the first batch
+            # holds 2 rows, so expect two records for each group.
+            expected = [Row(value=x) for x in [0, 0, 1, 1, 2, 2, 3, 3]]
+
+            # Order of rows across groups is not guaranteed after the shuffle; sort first.
+            actual = grouped_df.applyInPandas(func, "value long").orderBy("value").collect()
+            self.assertEqual(actual, expected)
+
 
 class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 23fceb746114..017698f318d5 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -183,8 +183,9 @@ class PandasUDFTestsMixin:
                 exception=pe.exception,
                 errorClass="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
                 messageParameters={
-                    "eval_type": "SQL_GROUPED_MAP_PANDAS_UDF "
-                    "or SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
+                    "eval_type": "SQL_GROUPED_MAP_PANDAS_UDF or "
+                    "SQL_GROUPED_MAP_PANDAS_ITER_UDF or "
+                    "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
                     "return_type": "DoubleType()",
                 },
             )
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
index a436b71c123d..09de2a2e3198 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py
@@ -26,9 +26,10 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.sql.pandas.typehints import infer_eval_type
+from pyspark.sql.pandas.typehints import infer_eval_type, 
infer_group_pandas_eval_type
 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
 from pyspark.sql import Row
+from pyspark.util import PythonEvalType
 
 if have_pandas:
     import pandas as pd
@@ -186,6 +187,49 @@ class PandasUDFTypeHintsTests(ReusedSQLTestCase):
             PandasUDFType.GROUPED_AGG,
         )
 
+    def test_type_annotation_group_map(self):
+        """Type hints select SQL_GROUPED_MAP_PANDAS_UDF, its iterator variant, or None."""
+
+        # pd.DataFrame -> pd.DataFrame
+        def func(col: pd.DataFrame) -> pd.DataFrame:
+            pass
+
+        self.assertEqual(
+            infer_group_pandas_eval_type(signature(func), get_type_hints(func)),
+            PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        )
+
+        # Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
+        def func(key: Tuple, col: pd.DataFrame) -> pd.DataFrame:
+            pass
+
+        self.assertEqual(
+            infer_group_pandas_eval_type(signature(func), get_type_hints(func)),
+            PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        )
+
+        # Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+        def func(col: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            pass
+
+        self.assertEqual(
+            infer_group_pandas_eval_type(signature(func), get_type_hints(func)),
+            PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
+        )
+
+        # Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
+        def func(key: Tuple, col: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+            pass
+
+        self.assertEqual(
+            infer_group_pandas_eval_type(signature(func), get_type_hints(func)),
+            PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
+        )
+
+        # Should return None for unsupported signatures
+        def func(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            pass
+
+        self.assertIsNone(infer_group_pandas_eval_type(signature(func), get_type_hints(func)))
+
     def test_type_annotation_negative(self):
         def func(col: str) -> pd.Series:
             pass
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index b28bccb04bc7..37aa30cc279f 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -238,6 +238,7 @@ class UserDefinedFunction:
                 )
         elif (
             evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+            or evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
             or evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE
         ):
             if isinstance(returnType, StructType):
@@ -256,6 +257,7 @@ class UserDefinedFunction:
                     errorClass="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
                     messageParameters={
                         "eval_type": "SQL_GROUPED_MAP_PANDAS_UDF or "
+                        "SQL_GROUPED_MAP_PANDAS_ITER_UDF or "
                         "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
                         "return_type": str(returnType),
                     },
@@ -282,7 +284,10 @@ class UserDefinedFunction:
                         "return_type": str(returnType),
                     },
                 )
-        elif evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+        elif (
+            evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+            or evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
+        ):
             if isinstance(returnType, StructType):
                 try:
                     to_arrow_type(returnType)
@@ -298,7 +303,7 @@ class UserDefinedFunction:
                 raise PySparkTypeError(
                     errorClass="INVALID_RETURN_TYPE_FOR_ARROW_UDF",
                     messageParameters={
-                        "eval_type": "SQL_GROUPED_MAP_ARROW_UDF",
+                        "eval_type": "SQL_GROUPED_MAP_ARROW_UDF or 
SQL_GROUPED_MAP_ARROW_ITER_UDF",
                         "return_type": str(returnType),
                     },
                 )
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index f94fc73b6435..f633ed699ee2 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -62,6 +62,7 @@ if typing.TYPE_CHECKING:
         ArrowGroupedMapUDFType,
         ArrowGroupedMapIterUDFType,
         ArrowCogroupedMapUDFType,
+        PandasGroupedMapIterUDFType,
         PandasGroupedMapUDFTransformWithStateType,
         PandasGroupedMapUDFTransformWithStateInitStateType,
         GroupedMapUDFTransformWithStateType,
@@ -652,6 +653,7 @@ class PythonEvalType:
         214
     )
     SQL_GROUPED_MAP_ARROW_ITER_UDF: "ArrowGroupedMapIterUDFType" = 215
+    SQL_GROUPED_MAP_PANDAS_ITER_UDF: "PandasGroupedMapIterUDFType" = 216
 
     # Arrow UDFs
     SQL_SCALAR_ARROW_UDF: "ArrowScalarUDFType" = 250
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b232b30c5420..09c6a40a33db 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -60,6 +60,7 @@ from pyspark.sql.pandas.serializers import (
     CogroupPandasUDFSerializer,
     ArrowStreamUDFSerializer,
     ApplyInPandasWithStateSerializer,
+    GroupPandasIterUDFSerializer,
     GroupPandasUDFSerializer,
     TransformWithStateInPandasSerializer,
     TransformWithStateInPandasInitStateSerializer,
@@ -748,6 +749,39 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, 
runner_conf):
     return lambda k, v: [(wrapped(k, v), arrow_return_type)]
 
 
+def wrap_grouped_map_pandas_iter_udf(f, return_type, argspec, runner_conf):
+    """
+    Wrap an iterator-style grouped map pandas UDF for execution in the worker.
+
+    ``f`` is called as ``f(batches)`` when ``argspec`` declares one positional argument, or
+    as ``f(key, batches)`` when it declares two. ``batches`` lazily yields one
+    ``pd.DataFrame`` per list of value Series, and ``key`` is a tuple built from the first
+    row of the key Series. The returned callable maps ``(key_series_list, value_series_gen)``
+    to ``(generator of verified result DataFrames, arrow_return_type)``. ``wrapped`` is a
+    generator function, so ``f`` is only invoked once that generator is iterated.
+    """
+    _use_large_var_types = use_large_var_types(runner_conf)
+    _assign_cols_by_name = assign_cols_by_name(runner_conf)
+    num_args = len(argspec.args)
+
+    def wrapped(key_series_list, value_series_gen):
+        import pandas as pd
+
+        # value_series_gen is a generator that yields multiple lists of Series (one per batch)
+        # Convert each list of Series into a DataFrame, lazily, as the UDF pulls them.
+        def dataframe_iter():
+            for value_series in value_series_gen:
+                yield pd.concat(value_series, axis=1)
+
+        if num_args == 1:
+            result = f(dataframe_iter())
+        elif num_args == 2:
+            # key_series_list holds the key columns of the group's first batch; the key is
+            # taken from the first row positionally so a non-default index cannot break it.
+            key = tuple(s.iloc[0] for s in key_series_list)
+            result = f(key, dataframe_iter())
+        else:
+            # Fail explicitly instead of hitting an UnboundLocalError on `result` below.
+            from pyspark.errors import PySparkValueError
+
+            raise PySparkValueError(
+                errorClass="INVALID_PANDAS_UDF",
+                messageParameters={
+                    "detail": "the function of a grouped map pandas iterator UDF must take "
+                    "either one argument (batches) or two arguments (key, batches), "
+                    f"but got {num_args}.",
+                },
+            )
+
+        def verify_element(df):
+            # Check each yielded DataFrame against the declared return schema.
+            verify_pandas_result(
+                df, return_type, _assign_cols_by_name, truncate_return_schema=False
+            )
+            return df
+
+        yield from map(verify_element, result)
+
+    arrow_return_type = to_arrow_type(return_type, _use_large_var_types)
+    return lambda k, v: (wrapped(k, v), arrow_return_type)
+
+
 def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
     def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
         result_iter = f(stateful_processor_api_client, mode, key, 
value_series_gen)
@@ -1262,6 +1296,11 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index, profil
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
         return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, 
argspec, runner_conf)
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+        argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
+        return args_offsets, wrap_grouped_map_pandas_iter_udf(
+            func, return_type, argspec, runner_conf
+        )
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost 
when wrapping it
         return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, 
argspec, runner_conf)
@@ -2565,6 +2604,7 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
@@ -2635,6 +2675,10 @@ def read_udfs(pickleSer, infile, eval_type):
             ser = GroupPandasUDFSerializer(
                 timezone, safecheck, _assign_cols_by_name, 
int_to_decimal_coercion_enabled
             )
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+            ser = GroupPandasIterUDFSerializer(
+                timezone, safecheck, _assign_cols_by_name, 
int_to_decimal_coercion_enabled
+            )
         elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
             ser = CogroupArrowUDFSerializer(_assign_cols_by_name)
         elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
@@ -2894,6 +2938,36 @@ def read_udfs(pickleSer, infile, eval_type):
             vals = [a[o] for o in parsed_offsets[0][1]]
             return f(keys, vals)
 
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF:
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, f = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0, 
profiler=profiler
+        )
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        # Create mapper similar to Arrow iterator:
+        # `a` is an iterator of Series lists (one list per batch, containing all columns).
+        # Materialize the first batch to get the keys, then create a generator for value batches.
+        def mapper(a):
+            """
+            Split the batches of one group into key and value Series and invoke the UDF.
+
+            Keys are read from the first batch only. That batch is chained back in front of
+            the remaining ones, so the UDF still receives every batch, lazily.
+            """
+            import itertools
+
+            series_iter = iter(a)
+            # Need to materialize the first series list to get the keys
+            first_series_list = next(series_iter)
+
+            keys = [first_series_list[o] for o in parsed_offsets[0][0]]
+            value_series_gen = (
+                [series_list[o] for o in parsed_offsets[0][1]]
+                for series_list in itertools.chain((first_series_list,), series_iter)
+            )
+
+            return f(keys, value_series_gen)
+
     elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
         # We assume there is only one UDF here because grouped map doesn't
         # support combining multiple UDFs.
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 41efe8db842f..8f8e6261066f 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -690,7 +690,8 @@ class SparkConnectPlanner(
           .groupBy(cols: _*)
 
         pythonUdf.evalType match {
-          case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF =>
+          case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF |
+              PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF =>
             group.flatMapGroupsInPandas(Column(pythonUdf)).logicalPlan
 
           case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF |
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
index 674c206c96a9..bd7b3348b9f0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
@@ -278,7 +278,8 @@ class RelationalGroupedDataset protected[sql](
    */
   private[sql] def flatMapGroupsInPandas(column: Column): DataFrame = {
     val expr = column.expr.asInstanceOf[PythonUDF]
-    require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+    require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF ||
+      expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
       "Must pass a grouped map pandas udf")
     require(expr.dataType.isInstanceOf[StructType],
       s"The returnType of the udf must be a ${StructType.simpleString}")
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 82c2f3fab200..765181d7e331 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 
@@ -46,7 +45,7 @@ case class FlatMapGroupsInPandasExec(
     child: SparkPlan)
   extends FlatMapGroupsInBatchExec {
 
-  protected val pythonEvalType: Int = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+  // Use the eval type carried by the UDF instead of a hardcoded constant, so that both
+  // SQL_GROUPED_MAP_PANDAS_UDF and SQL_GROUPED_MAP_PANDAS_ITER_UDF are passed through.
+  // NOTE(review): the cast assumes `func` is always a PythonUDF for this node.
+  protected val pythonEvalType: Int = func.asInstanceOf[PythonUDF].evalType
 
   override protected def withNewChildInternal(newChild: SparkPlan): 
FlatMapGroupsInPandasExec =
     copy(child = newChild)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to