This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e122017df62 [SPARK-49547][SQL][PYTHON] Add iterator of `RecordBatch` API to `applyInArrow`
9e122017df62 is described below

commit 9e122017df62b5588693eaaeb7d59225370ed1ec
Author: Adam Binford <[email protected]>
AuthorDate: Sat Sep 27 08:34:59 2025 +0800

    [SPARK-49547][SQL][PYTHON] Add iterator of `RecordBatch` API to `applyInArrow`
    
    ### What changes were proposed in this pull request?
    Add the option for `applyInArrow` to take a function that accepts an iterator of `RecordBatch` and returns an iterator of `RecordBatch`. A new eval type, `SQL_GROUPED_MAP_ARROW_ITER_UDF`, is added and is selected by inspecting the function's type hints; a function without type hints keeps the existing `SQL_GROUPED_MAP_ARROW_UDF` (`pyarrow.Table`) behavior.
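    
    The dispatch on type hints can be sketched roughly as follows. This is a simplified, stand-alone model that loosely mirrors the checks in `infer_group_arrow_eval_type`, not the code itself: the classes `RecordBatch` and `Table` are stand-ins for `pyarrow.RecordBatch` and `pyarrow.Table` so that it runs without pyarrow, `classify_grouped_func` is a hypothetical name, and it raises a plain `ValueError` where PySpark raises its own error class.
    
    ```python
    import collections.abc
    import inspect
    from typing import (
        Any,
        Callable,
        Iterator,
        Optional,
        Tuple,
        get_args,
        get_origin,
        get_type_hints,
    )


    class RecordBatch:  # stand-in for pyarrow.RecordBatch (illustration only)
        pass


    class Table:  # stand-in for pyarrow.Table (illustration only)
        pass


    def _is_iterator_of(annotation: Any, item: type) -> bool:
        # typing.Iterator[X] reports collections.abc.Iterator as its origin and (X,) as its args.
        return (
            get_origin(annotation) is collections.abc.Iterator
            and get_args(annotation) == (item,)
        )


    def classify_grouped_func(func: Callable[..., Any]) -> Optional[str]:
        """Return "iter", "table", or None when the hints do not select a form."""
        params = list(inspect.signature(func).parameters)
        hints = get_type_hints(func)
        if not hints:
            return None  # no hints at all: the caller keeps the default Table form
        if any(p not in hints for p in params) or "return" not in hints:
            raise ValueError("all parameters and the return type need type hints")
        if len(params) not in (1, 2):
            return None
        data_hint = hints[params[-1]]  # data is the last argument; an optional key comes first
        ret_hint = hints["return"]
        if _is_iterator_of(data_hint, RecordBatch) and _is_iterator_of(ret_hint, RecordBatch):
            return "iter"
        if data_hint is Table and ret_hint is Table:
            return "table"
        return None


    def sum_func(key: Tuple[Any, ...], batches: Iterator[RecordBatch]) -> Iterator[RecordBatch]:
        yield from batches


    print(classify_grouped_func(sum_func))  # iter
    ```
    
    Checking only the last parameter lets the same test cover both the key-less and the key-aware signatures.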
    
    ### Why are the changes needed?
    Taking a single Table as input and returning a single Table as output requires holding all of a group's input, and all of its output, in memory at once. For very large groups this can require excessive memory. Inputs and outputs are already serialized as record batches, so this change exposes that lazy iterator of batches directly instead of forcing materialization into a Table.
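    
    The lazy hand-off can be modelled in plain Python. In this sketch (an illustration, not Spark's serializer), one shared stream carries `(group_id, batch)` pairs with each group's items contiguous; each group is exposed as a lazy iterator, and whatever the consumer leaves unread is drained before the next group starts, similar to the drain loop this commit adds to `GroupArrowUDFSerializer`. `iter_groups` and the list-of-floats batches are hypothetical stand-ins.
    
    ```python
    from typing import Iterator, List, Optional, Tuple

    Batch = List[float]


    def iter_groups(
        items: Iterator[Tuple[int, Batch]],
    ) -> Iterator[Tuple[int, Iterator[Batch]]]:
        """Yield (group_id, lazy iterator over that group's batches) from one shared stream.

        Assumes items with the same group_id are contiguous, as in sorted grouped input.
        """
        stream = iter(items)
        pending: Optional[Tuple[int, Batch]] = next(stream, None)  # next unread item
        while pending is not None:
            group_id = pending[0]

            def batches() -> Iterator[Batch]:
                nonlocal pending
                # Hand out items only while they still belong to the current group.
                while pending is not None and pending[0] == group_id:
                    batch = pending[1]
                    pending = next(stream, None)
                    yield batch

            group_iter = batches()
            yield group_id, group_iter
            # Drain whatever the consumer left unread so the shared stream is
            # positioned at the first item of the next group.
            for _ in group_iter:
                pass


    stream_items = [(1, [1.0]), (1, [2.0]), (2, [3.0, 5.0]), (2, [10.0])]
    totals = {
        gid: sum(sum(b) for b in batches)  # reads one batch at a time
        for gid, batches in iter_groups(iter(stream_items))
    }
    print(totals)  # {1: 3.0, 2: 18.0}
    ```
    
    Without the drain step, a consumer that stops early would leave the rest of its group at the head of the stream, and those leftover batches would be handed out again as if they were a new group.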
    
    ### Does this PR introduce _any_ user-facing change?
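    Yes. `applyInArrow` now accepts a function that takes an iterator of `pyarrow.RecordBatch` (optionally preceded by a tuple of grouping keys) and returns an iterator of `pyarrow.RecordBatch`.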
    
    Example:
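    ```python
    from typing import Iterator, Tuple

    import pyarrow as pa
    import pyarrow.compute as pc

    # Assumes an active SparkSession named `spark`.
    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
    )

    def sum_func(
        key: Tuple[pa.Scalar, ...], batches: Iterator[pa.RecordBatch]
    ) -> Iterator[pa.RecordBatch]:
        # Keep only a running total, so a group is never materialized as one Table.
        total = 0
        for batch in batches:
            total += pc.sum(batch.column("v")).as_py()
        yield pa.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})

    df.groupby("id").applyInArrow(sum_func, schema="id long, v double").show()
    ```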
    ```
    +---+----+
    | id|   v|
    +---+----+
    |  1| 3.0|
    |  2|18.0|
    +---+----+
    ```
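    
    Because both the Table form and the iterator form are plain Python callables, they can be compared locally without a Spark session. The sketch below assumes pyarrow is installed; `sum_table` and `sum_iter` are hypothetical names, and the two hand-built batches only approximate how Spark might split a group whose `v` values are 3.0, 5.0 and 10.0 (the group that yields 18.0 above).
    
    ```python
    from typing import Iterator, Tuple

    import pyarrow as pa
    import pyarrow.compute as pc


    def sum_table(key: Tuple[pa.Scalar, ...], table: pa.Table) -> pa.Table:
        # Table form: the whole group is materialized as one pa.Table before the call.
        total = pc.sum(table.column("v")).as_py()
        return pa.Table.from_pydict({"id": [key[0].as_py()], "v": [total]})


    def sum_iter(
        key: Tuple[pa.Scalar, ...], batches: Iterator[pa.RecordBatch]
    ) -> Iterator[pa.RecordBatch]:
        # Iterator form: batches arrive one at a time and only a running total is kept.
        total = 0.0
        for batch in batches:
            total += pc.sum(batch.column("v")).as_py()
        yield pa.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})


    # Hand-built input for the group with id 2, split into two batches.
    key = (pa.scalar(2, pa.int64()),)
    group = [
        pa.RecordBatch.from_pydict({"id": [2, 2], "v": [3.0, 5.0]}),
        pa.RecordBatch.from_pydict({"id": [2], "v": [10.0]}),
    ]

    from_table = sum_table(key, pa.Table.from_batches(group))
    from_iter = pa.Table.from_batches(list(sum_iter(key, iter(group))))
    print(from_table.to_pydict())  # {'id': [2], 'v': [18.0]}
    print(from_iter.to_pydict())  # {'id': [2], 'v': [18.0]}
    ```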
    
    ### How was this patch tested?
    Updated existing UTs to test both the Table signatures and the iterator of RecordBatch signatures.
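    
    The adapter idea behind these tests (running every Table-based test function a second time through an equivalent iterator-of-batches wrapper, as `function_variations` does) can be sketched in plain Python. Lists of row tuples stand in for `pyarrow.Table` and `pyarrow.RecordBatch`, and `as_iterator_udf` and its `batch_size` parameter are hypothetical; the helper in this commit uses `pa.Table.from_batches` and `Table.to_batches` instead.
    
    ```python
    import inspect
    from typing import Any, Callable, Iterator, List, Tuple

    Row = Tuple[Any, ...]
    Table = List[Row]  # stand-in for pyarrow.Table (illustration only)
    Batch = List[Row]  # stand-in for pyarrow.RecordBatch (illustration only)


    def as_iterator_udf(
        func: Callable[..., Table], batch_size: int = 2
    ) -> Callable[..., Iterator[Batch]]:
        """Adapt a Table-style UDF, with or without a leading key argument, to the iterator form."""

        def to_table(batches: Iterator[Batch]) -> Table:
            # Concatenate the incoming batches into one "table".
            return [row for batch in batches for row in batch]

        def rechunk(table: Table) -> Iterator[Batch]:
            # Split the result back into batches of at most batch_size rows.
            for start in range(0, len(table), batch_size):
                yield table[start : start + batch_size]

        if len(inspect.signature(func).parameters) == 1:

            def iter_func(batches: Iterator[Batch]) -> Iterator[Batch]:
                yield from rechunk(func(to_table(batches)))

            return iter_func

        def iter_keys_func(key: Row, batches: Iterator[Batch]) -> Iterator[Batch]:
            yield from rechunk(func(key, to_table(batches)))

        return iter_keys_func


    def double_v(table: Table) -> Table:
        return [(i, v * 2) for i, v in table]


    print(list(as_iterator_udf(double_v)(iter([[(1, 1.0), (1, 2.0)], [(1, 3.0)]]))))
    # [[(1, 2.0), (1, 4.0)], [(1, 6.0)]]
    ```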
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52440 from Kimahriman/apply-in-arrow-iter-eval.
    
    Authored-by: Adam Binford <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../org/apache/spark/api/python/PythonRunner.scala |   1 +
 python/pyspark/sql/connect/group.py                |  14 +-
 python/pyspark/sql/pandas/_typing/__init__.pyi     |  12 +-
 python/pyspark/sql/pandas/functions.py             |   2 +
 python/pyspark/sql/pandas/group_ops.py             |  89 ++++++++--
 python/pyspark/sql/pandas/serializers.py           |  19 +-
 python/pyspark/sql/pandas/typehints.py             |  91 ++++++++++
 .../sql/tests/arrow/test_arrow_grouped_map.py      | 195 +++++++++++++++------
 .../sql/tests/arrow/test_arrow_udf_typehints.py    |  25 ++-
 python/pyspark/util.py                             |   2 +
 python/pyspark/worker.py                           | 126 ++++++++++---
 .../sql/connect/planner/SparkConnectPlanner.scala  |   3 +-
 .../sql/classic/RelationalGroupedDataset.scala     |   3 +-
 .../python/FlatMapGroupsInArrowExec.scala          |  22 ++-
 14 files changed, 479 insertions(+), 125 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index ca4f6e56554e..8e5b7ef001b8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -66,6 +66,7 @@ private[spark] object PythonEvalType {
   val SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF = 212
   val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF = 213
   val SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF = 214
+  val SQL_GROUPED_MAP_ARROW_ITER_UDF = 215
 
   // Arrow UDFs
   val SQL_SCALAR_ARROW_UDF = 250
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 04f8f26ecf38..4bba2b1e0b71 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -35,6 +35,7 @@ from pyspark.util import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
 from pyspark.sql.pandas.functions import _validate_vectorized_udf  # type: ignore[attr-defined]
+from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
 from pyspark.sql.types import NumericType, StructType
 
 import pyspark.sql.connect.plan as plan
@@ -472,13 +473,22 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
 
-        _validate_vectorized_udf(func, PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF)
+        try:
+            # Try to infer the eval type from type hints
+            eval_type = infer_group_arrow_eval_type_from_func(func)
+        except Exception:
+            warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)
+
+        if eval_type is None:
+            eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
+        _validate_vectorized_udf(func, eval_type)
         if isinstance(schema, str):
             schema = cast(StructType, self._df._session._parse_ddl(schema))
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
-            evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+            evalType=eval_type,
         )
 
         res = DataFrame(
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi
index d1e2b7aae6f8..cea44921069f 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -20,6 +20,7 @@ from typing import (
     Any,
     Callable,
     Iterable,
+    Iterator,
     NewType,
     Tuple,
     Type,
@@ -59,6 +60,7 @@ PandasGroupedMapUDFTransformWithStateType = Literal[211]
 PandasGroupedMapUDFTransformWithStateInitStateType = Literal[212]
 GroupedMapUDFTransformWithStateType = Literal[213]
 GroupedMapUDFTransformWithStateInitStateType = Literal[214]
+ArrowGroupedMapIterUDFType = Literal[215]
 
 # Arrow UDFs
 ArrowScalarUDFType = Literal[250]
@@ -430,10 +432,18 @@ PandasCogroupedMapFunction = Union[
     Callable[[Any, DataFrameLike, DataFrameLike], DataFrameLike],
 ]
 
-ArrowGroupedMapFunction = Union[
+ArrowGroupedMapTableFunction = Union[
     Callable[[pyarrow.Table], pyarrow.Table],
     Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
 ]
+ArrowGroupedMapIterFunction = Union[
+    Callable[[Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]],
+    Callable[
+        [Tuple[pyarrow.Scalar, ...], Iterator[pyarrow.RecordBatch]], Iterator[pyarrow.RecordBatch]
+    ],
+]
+ArrowGroupedMapFunction = Union[ArrowGroupedMapTableFunction, ArrowGroupedMapIterFunction]
+
 ArrowCogroupedMapFunction = Union[
     Callable[[pyarrow.Table, pyarrow.Table], pyarrow.Table],
     Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table, pyarrow.Table], pyarrow.Table],
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 0c571ba0acfb..55a9af2fdee4 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -700,6 +700,7 @@ def vectorized_udf(
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
         PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         None,
     ]:  # None means it should infer the type from type hints.
@@ -779,6 +780,7 @@ def _validate_vectorized_udf(f, evalType, kind: str = "pandas") -> int:
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
         PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
     ]:
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 08e795e99a1d..07d78a6ce6d8 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -22,6 +22,7 @@ from pyspark.errors import PySparkTypeError
 from pyspark.util import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
 from pyspark.sql.streaming.state import GroupStateTimeout
 from pyspark.sql.streaming.stateful_processor import StatefulProcessor
 from pyspark.sql.types import StructType
@@ -703,27 +704,33 @@ class PandasGroupedOpsMixin:
         Maps each group of the current :class:`DataFrame` using an Arrow udf and returns the result
         as a `DataFrame`.
 
-        The function should take a `pyarrow.Table` and return another
-        `pyarrow.Table`. Alternatively, the user can pass a function that takes
-        a tuple of `pyarrow.Scalar` grouping key(s) and a `pyarrow.Table`.
-        For each group, all columns are passed together as a `pyarrow.Table`
-        to the user-function and the returned `pyarrow.Table` are combined as a
-        :class:`DataFrame`.
+        The function can take one of two forms: It can take a `pyarrow.Table` and return a
+        `pyarrow.Table`, or it can take an iterator of `pyarrow.RecordBatch` and yield
+        `pyarrow.RecordBatch`. Alternatively each form can take a tuple of `pyarrow.Scalar`
+        as the first argument in addition to the input type above. For each group, all columns
+        are passed together in the `pyarrow.Table` or `pyarrow.RecordBatch`, and the returned
+        `pyarrow.Table` or iterator of `pyarrow.RecordBatch` are combined as a :class:`DataFrame`.
 
         The `schema` should be a :class:`StructType` describing the schema of the returned
-        `pyarrow.Table`. The column labels of the returned `pyarrow.Table` must either match
-        the field names in the defined schema if specified as strings, or match the
-        field data types by position if not strings, e.g. integer indices.
-        The length of the returned `pyarrow.Table` can be arbitrary.
+        `pyarrow.Table` or `pyarrow.RecordBatch`. The column labels of the returned `pyarrow.Table`
+        or `pyarrow.RecordBatch` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings, e.g.
+        integer indices. The length of the returned `pyarrow.Table` or iterator of
+        `pyarrow.RecordBatch` can be arbitrary.
 
         .. versionadded:: 4.0.0
 
+        .. versionchanged:: 4.1.0
+           Added support for an iterator of `pyarrow.RecordBatch` API.
+
         Parameters
         ----------
         func : function
-            a Python native function that takes a `pyarrow.Table` and outputs a
-            `pyarrow.Table`, or that takes one tuple (grouping keys) and a
-            `pyarrow.Table` and outputs a `pyarrow.Table`.
+            a Python native function that either takes a `pyarrow.Table` and outputs a
+            `pyarrow.Table` or takes an iterator of `pyarrow.RecordBatch` and yields
+            `pyarrow.RecordBatch`. Additionally, each form can take a tuple of grouping keys
+            as the first argument, with the `pyarrow.Table` or iterator of `pyarrow.RecordBatch`
+            as the second argument.
         schema : :class:`pyspark.sql.types.DataType` or str
             the return type of the `func` in PySpark. The value can be either a
             :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
@@ -752,6 +759,28 @@ class PandasGroupedOpsMixin:
         |  2| 1.1094003924504583|
         +---+-------------------+
 
+        The function can also take and return an iterator of `pyarrow.RecordBatch` using type
+        hints.
+
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))  # doctest: +SKIP
+        >>> def sum_func(
+        ...     batches: Iterator[pyarrow.RecordBatch]
+        ... ) -> Iterator[pyarrow.RecordBatch]:  # doctest: +SKIP
+        ...     total = 0
+        ...     for batch in batches:
+        ...         total += pc.sum(batch.column("v")).as_py()
+        ...     yield pyarrow.RecordBatch.from_pydict({"v": [total]})
+        >>> df.groupby("id").applyInArrow(
+        ...     sum_func, schema="v double").show()  # doctest: +SKIP
+        +----+
+        |   v|
+        +----+
+        | 3.0|
+        |18.0|
+        +----+
+
         Alternatively, the user can pass a function that takes two arguments.
         In this case, the grouping key(s) will be passed as the first argument and the data will
         be passed as the second argument. The grouping key(s) will be passed as a tuple of Arrow
@@ -796,11 +825,28 @@ class PandasGroupedOpsMixin:
         |  2|          2| 3.0|
         +---+-----------+----+
 
+        >>> def sum_func(
+        ...     key: Tuple[pyarrow.Scalar, ...], batches: Iterator[pyarrow.RecordBatch]
+        ... ) -> Iterator[pyarrow.RecordBatch]:  # doctest: +SKIP
+        ...     total = 0
+        ...     for batch in batches:
+        ...         total += pc.sum(batch.column("v")).as_py()
+        ...     yield pyarrow.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})
+        >>> df.groupby("id").applyInArrow(
+        ...     sum_func, schema="id long, v double").show()  # doctest: +SKIP
+        +---+----+
+        | id|   v|
+        +---+----+
+        |  1| 3.0|
+        |  2|18.0|
+        +---+----+
+
         Notes
         -----
-        This function requires a full shuffle. All the data of a group will be loaded
-        into memory, so the user should be aware of the potential OOM risk if data is skewed
-        and certain groups are too large to fit in memory.
+        This function requires a full shuffle. If using the `pyarrow.Table` API, all data of a
+        group will be loaded into memory, so the user should be aware of the potential OOM risk
+        if data is skewed and certain groups are too large to fit in memory, and can use the
+        iterator of `pyarrow.RecordBatch` API to mitigate this.
 
         This API is unstable, and for developers.
 
@@ -813,9 +859,18 @@ class PandasGroupedOpsMixin:
 
         assert isinstance(self, GroupedData)
 
+        try:
+            # Try to infer the eval type from type hints
+            eval_type = infer_group_arrow_eval_type_from_func(func)
+        except Exception:
+            warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)
+
+        if eval_type is None:
+            eval_type = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
         # The usage of the pandas_udf is internal so type checking is disabled.
         udf = pandas_udf(
-            func, returnType=schema, functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+            func, returnType=schema, functionType=eval_type
         )  # type: ignore[call-overload]
         df = self._df
         udf_column = udf(*[df[col] for col in df.columns])
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 801a87c06cc5..35eeb11861a6 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -21,7 +21,7 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for mo
 
 from decimal import Decimal
 from itertools import groupby
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Iterator, Optional
 
 import pyspark
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
@@ -1116,19 +1116,22 @@ class GroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
         """
         import pyarrow as pa
 
+        def process_group(batches: "Iterator[pa.RecordBatch]"):
+            for batch in batches:
+                struct = batch.column(0)
+                yield pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))
+
         dataframes_in_group = None
 
         while dataframes_in_group is None or dataframes_in_group > 0:
             dataframes_in_group = read_int(stream)
 
             if dataframes_in_group == 1:
-                structs = [
-                    batch.column(0) for batch in ArrowStreamSerializer.load_stream(self, stream)
-                ]
-                yield [
-                    pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))
-                    for struct in structs
-                ]
+                batch_iter = process_group(ArrowStreamSerializer.load_stream(self, stream))
+                yield batch_iter
+                # Make sure the batches are fully iterated before getting the next group
+                for _ in batch_iter:
+                    pass
 
             elif dataframes_in_group != 0:
                 raise PySparkValueError(
diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py
index 610bd1df40ac..c184e0dc5668 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -29,6 +29,9 @@ if TYPE_CHECKING:
         ArrowScalarUDFType,
         ArrowScalarIterUDFType,
         ArrowGroupedAggUDFType,
+        ArrowGroupedMapIterUDFType,
+        ArrowGroupedMapUDFType,
+        ArrowGroupedMapFunction,
     )
 
 
@@ -303,6 +306,94 @@ def infer_eval_type_for_udf(  # type: ignore[no-untyped-def]
         return None
 
 
+def infer_group_arrow_eval_type(
+    sig: Signature,
+    type_hints: Dict[str, Any],
+) -> Optional[Union["ArrowGroupedMapUDFType", "ArrowGroupedMapIterUDFType"]]:
+    from pyspark.sql.pandas.functions import PythonEvalType
+
+    require_minimum_pyarrow_version()
+
+    import pyarrow as pa
+
+    annotations = {}
+    for param in sig.parameters.values():
+        if param.annotation is not param.empty:
+            annotations[param.name] = type_hints.get(param.name, param.annotation)
+
+    # Check if all arguments have type hints
+    parameters_sig = [
+        annotations[parameter] for parameter in sig.parameters if parameter in annotations
+    ]
+    if len(parameters_sig) != len(sig.parameters):
+        raise PySparkValueError(
+            errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+            messageParameters={"target": "all parameters", "sig": str(sig)},
+        )
+
+    # Check if the return has a type hint
+    return_annotation = type_hints.get("return", sig.return_annotation)
+    if sig.empty is return_annotation:
+        raise PySparkValueError(
+            errorClass="TYPE_HINT_SHOULD_BE_SPECIFIED",
+            messageParameters={"target": "the return type", "sig": str(sig)},
+        )
+
+    # Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
+    is_iterator_batch = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[0],
+            parameter_check_func=lambda t: t == pa.RecordBatch,
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
+        )
+    )
+    # Tuple[pa.Scalar, ...], Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
+    is_iterator_batch_with_keys = (
+        len(parameters_sig) == 2
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[1],
+            parameter_check_func=lambda t: t == pa.RecordBatch,
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
+        )
+    )
+
+    if is_iterator_batch or is_iterator_batch_with_keys:
+        return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
+
+    # pa.Table -> pa.Table
+    is_table = (
+        len(parameters_sig) == 1 and parameters_sig[0] == pa.Table and return_annotation == pa.Table
+    )
+    # Tuple[pa.Scalar, ...], pa.Table -> pa.Table
+    is_table_with_keys = (
+        len(parameters_sig) == 2 and parameters_sig[1] == pa.Table and return_annotation == pa.Table
+    )
+    if is_table or is_table_with_keys:
+        return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
+    return None
+
+
+def infer_group_arrow_eval_type_from_func(
+    f: "ArrowGroupedMapFunction",
+) -> Optional[Union["ArrowGroupedMapUDFType", "ArrowGroupedMapIterUDFType"]]:
+    argspec = getfullargspec(f)
+    if len(argspec.annotations) > 0:
+        try:
+            type_hints = get_type_hints(f)
+        except NameError:
+            type_hints = {}
+
+        return infer_group_arrow_eval_type(signature(f), type_hints)
+    else:
+        return None
+
+
 def check_tuple_annotation(
     annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None
 ) -> bool:
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 4de88d097a68..e1cd507737cf 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -14,8 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import inspect
 import os
 import time
+from typing import Iterator, Tuple
 import unittest
 
 from pyspark.errors import PythonException
@@ -33,6 +35,25 @@ if have_pyarrow:
     import pyarrow.compute as pc
 
 
+def function_variations(func):
+    yield func
+    num_args = len(inspect.getfullargspec(func).args)
+    if num_args == 1:
+
+        def iter_func(batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            yield from func(pa.Table.from_batches(batches)).to_batches()
+
+        yield iter_func
+    else:
+
+        def iter_keys_func(
+            keys: Tuple[pa.Scalar, ...], batches: Iterator[pa.RecordBatch]
+        ) -> Iterator[pa.RecordBatch]:
+            yield from func(keys, pa.Table.from_batches(batches)).to_batches()
+
+        yield iter_keys_func
+
+
 @unittest.skipIf(
     not have_pyarrow,
     pyarrow_requirement_message,  # type: ignore[arg-type]
@@ -58,8 +79,9 @@ class ApplyInArrowTestsMixin:
         grouped_df = df.groupBy((col("id") / 4).cast("int"))
         expected = df.collect()
 
-        actual = grouped_df.applyInArrow(func, "id long, value long").collect()
-        self.assertEqual(actual, expected)
+        for func_variation in function_variations(func):
+            actual = grouped_df.applyInArrow(func_variation, "id long, value long").collect()
+            self.assertEqual(actual, expected)
 
     def test_apply_in_arrow_with_key(self):
         def func(key, group):
@@ -76,8 +98,9 @@ class ApplyInArrowTestsMixin:
         grouped_df = df.groupBy((col("id") / 4).cast("int"))
         expected = df.collect()
 
-        actual2 = grouped_df.applyInArrow(func, "id long, value long").collect()
-        self.assertEqual(actual2, expected)
+        for func_variation in function_variations(func):
+            actual2 = grouped_df.applyInArrow(func_variation, "id long, value long").collect()
+            self.assertEqual(actual2, expected)
 
     def test_apply_in_arrow_empty_groupby(self):
         df = self.data
@@ -88,20 +111,23 @@ class ApplyInArrowTestsMixin:
                 1, "v", pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
             )
 
-        # casting doubles to floats to get rid of numerical precision issues
-        # when comparing Arrow and Spark values
-        actual = (
-            df.groupby()
-            .applyInArrow(normalize, "id long, v double")
-            .withColumn("v", col("v").cast("float"))
-            .sort("id", "v")
-        )
-        windowSpec = Window.partitionBy()
-        expected = df.withColumn(
-            "v",
-            ((df.v - mean(df.v).over(windowSpec)) / stddev(df.v).over(windowSpec)).cast("float"),
-        )
-        self.assertEqual(actual.collect(), expected.collect())
+        for func_variation in function_variations(normalize):
+            # casting doubles to floats to get rid of numerical precision issues
+            # when comparing Arrow and Spark values
+            actual = (
+                df.groupby()
+                .applyInArrow(func_variation, "id long, v double")
+                .withColumn("v", col("v").cast("float"))
+                .sort("id", "v")
+            )
+            windowSpec = Window.partitionBy()
+            expected = df.withColumn(
+                "v",
+                ((df.v - mean(df.v).over(windowSpec)) / stddev(df.v).over(windowSpec)).cast(
+                    "float"
+                ),
+            )
+            self.assertEqual(actual.collect(), expected.collect())
 
     def test_apply_in_arrow_not_returning_arrow_table(self):
         df = self.data
@@ -109,6 +135,11 @@ class ApplyInArrowTestsMixin:
         def stats(key, _):
             return key
 
+        def stats_iter(
+            key: Tuple[pa.Scalar, ...], _: Iterator[pa.RecordBatch]
+        ) -> Iterator[pa.RecordBatch]:
+            yield key
+
         with self.quiet():
             with self.assertRaisesRegex(
                 PythonException,
@@ -116,6 +147,13 @@ class ApplyInArrowTestsMixin:
             ):
                 df.groupby("id").applyInArrow(stats, schema="id long, m double").collect()
 
+            with self.assertRaisesRegex(
+                PythonException,
+                "Return type of the user-defined function should be pyarrow.RecordBatch, but is "
+                + "tuple",
+            ):
+                df.groupby("id").applyInArrow(stats_iter, schema="id long, m double").collect()
+
     def test_apply_in_arrow_returning_wrong_types(self):
         df = self.data
 
@@ -131,11 +169,12 @@ class ApplyInArrowTestsMixin:
         ]:
             with self.subTest(schema=schema):
                 with self.quiet():
-                    with self.assertRaisesRegex(
-                        PythonException,
-                        f"Columns do not match in their data type: {expected}",
-                    ):
-                        df.groupby("id").applyInArrow(lambda table: table, schema=schema).collect()
+                    for func_variation in function_variations(lambda table: table):
+                        with self.assertRaisesRegex(
+                            PythonException,
+                            f"Columns do not match in their data type: {expected}",
+                        ):
+                            df.groupby("id").applyInArrow(func_variation, schema=schema).collect()
 
     def test_apply_in_arrow_returning_wrong_types_positional_assignment(self):
         df = self.data
@@ -155,13 +194,14 @@ class ApplyInArrowTestsMixin:
                     {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
                 ):
                     with self.quiet():
-                        with self.assertRaisesRegex(
-                            PythonException,
-                            f"Columns do not match in their data type: {expected}",
-                        ):
-                            df.groupby("id").applyInArrow(
-                                lambda table: table, schema=schema
-                            ).collect()
+                        for func_variation in function_variations(lambda table: table):
+                            with self.assertRaisesRegex(
+                                PythonException,
+                                f"Columns do not match in their data type: {expected}",
+                            ):
+                                df.groupby("id").applyInArrow(
+                                    func_variation, schema=schema
+                                ).collect()
 
     def test_apply_in_arrow_returning_wrong_column_names(self):
         df = self.data
@@ -177,13 +217,16 @@ class ApplyInArrowTestsMixin:
             )
 
         with self.quiet():
-            with self.assertRaisesRegex(
-                PythonException,
-                "Column names of the returned pyarrow.Table do not match specified schema. "
-                "Missing: m. Unexpected: v, v2.\n",
-            ):
-                # stats returns three columns while here we set schema with two columns
-                df.groupby("id").applyInArrow(stats, schema="id long, m double").collect()
+            for func_variation in function_variations(stats):
+                with self.assertRaisesRegex(
+                    PythonException,
+                    "Column names of the returned pyarrow.Table do not match specified schema. "
+                    "Missing: m. Unexpected: v, v2.\n",
+                ):
+                    # stats returns three columns while here we set schema with two columns
+                    df.groupby("id").applyInArrow(
+                        func_variation, schema="id long, m double"
+                    ).collect()
 
     def test_apply_in_arrow_returning_empty_dataframe(self):
         df = self.data
@@ -197,9 +240,12 @@ class ApplyInArrowTestsMixin:
                 )
 
         schema = "id long, m double"
-        actual = df.groupby("id").applyInArrow(odd_means, schema=schema).sort("id").collect()
-        expected = [Row(id=id, m=24.5) for id in range(1, 10, 2)]
-        self.assertEqual(expected, actual)
+        for func_variation in function_variations(odd_means):
+            actual = (
+                df.groupby("id").applyInArrow(func_variation, schema=schema).sort("id").collect()
+            )
+            expected = [Row(id=id, m=24.5) for id in range(1, 10, 2)]
+            self.assertEqual(expected, actual)
 
     def test_apply_in_arrow_returning_empty_dataframe_and_wrong_column_names(self):
         df = self.data
@@ -230,14 +276,15 @@ class ApplyInArrowTestsMixin:
         def change_col_order(table):
             return table.append_column("u", pc.multiply(table.column("v"), 3))
 
-        # The result should assign columns by name from the table
-        result = (
-            grouped_df.applyInArrow(change_col_order, "id long, u long, v int")
-            .sort("id", "v")
-            .select("id", "u", "v")
-            .collect()
-        )
-        self.assertEqual(expected, result)
+        for func_variation in function_variations(change_col_order):
+            # The result should assign columns by name from the table
+            result = (
+                grouped_df.applyInArrow(func_variation, "id long, u long, v int")
+                .sort("id", "v")
+                .select("id", "u", "v")
+                .collect()
+            )
+            self.assertEqual(expected, result)
 
     def test_positional_assignment_conf(self):
         with self.sql_conf(
@@ -248,12 +295,52 @@ class ApplyInArrowTestsMixin:
                 return pa.Table.from_pydict({"x": ["hi"], "y": [1]})
 
             df = self.data
-            result = (
-                df.groupBy("id").applyInArrow(foo, "a string, b long").select("a", "b").collect()
-            )
-            for r in result:
-                self.assertEqual(r.a, "hi")
-                self.assertEqual(r.b, 1)
+            for func_variation in function_variations(foo):
+                result = (
+                    df.groupBy("id")
+                    .applyInArrow(func_variation, "a string, b long")
+                    .select("a", "b")
+                    .collect()
+                )
+                for r in result:
+                    self.assertEqual(r.a, "hi")
+                    self.assertEqual(r.b, 1)
+
+    def test_apply_in_arrow_batching(self):
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
+
+            def func(group: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+                assert isinstance(group, Iterator)
+                batches = list(group)
+                assert len(batches) == 2
+                for batch in batches:
+                    assert isinstance(batch, pa.RecordBatch)
+                    assert batch.schema.names == ["id", "value"]
+                yield from batches
+
+            df = self.spark.range(12).withColumn("value", col("id") * 10)
+            grouped_df = df.groupBy((col("id") / 4).cast("int"))
+
+            actual = grouped_df.applyInArrow(func, "id long, value long").collect()
+            self.assertEqual(actual, df.collect())
+
+    def test_apply_in_arrow_partial_iteration(self):
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
+
+            def func(group: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+                first = next(group)
+                yield pa.RecordBatch.from_pylist(
+                    [{"value": r.as_py() % 4} for r in first.column(0)]
+                )
+
+            df = self.spark.range(20)
+            grouped_df = df.groupBy((col("id") % 4).cast("int"))
+
+            # Should get two records for each group
+            expected = [Row(value=x) for x in [0, 0, 1, 1, 2, 2, 3, 3]]
+
+            actual = grouped_df.applyInArrow(func, "value long").collect()
+            self.assertEqual(actual, expected)
 
     def test_self_join(self):
         df = self.spark.createDataFrame([(1, 1)], ("k", "v"))
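The two new tests above depend on how `spark.sql.execution.arrow.maxRecordsPerBatch` splits each group before it reaches the UDF. The following pure-Python sketch is a rough illustration only; it reproduces the batch counts those tests assert. Several parts are assumptions: `simulate_apply_in_arrow_iter` is a hypothetical helper (not a PySpark API), plain lists stand in for `pyarrow.RecordBatch`, and each group is assumed to be cut into consecutive chunks of at most `maxRecordsPerBatch` rows.

```python
from itertools import groupby
from typing import Callable, Iterable, Iterator, List

# Hypothetical stand-ins: a "batch" is a plain list of row values.
Batch = List[int]
GroupFunc = Callable[[Iterator[Batch]], Iterator[Batch]]


def simulate_apply_in_arrow_iter(
    rows: Iterable[int], key: Callable[[int], int], max_records_per_batch: int, func: GroupFunc
) -> List[int]:
    """Group rows by key (in key order), cut each group into chunks of at most
    max_records_per_batch rows, and hand func an iterator over those chunks."""
    output: List[int] = []
    for _, group in groupby(sorted(rows, key=key), key=key):
        group_rows = list(group)
        batches = (
            group_rows[i : i + max_records_per_batch]
            for i in range(0, len(group_rows), max_records_per_batch)
        )
        for out_batch in func(batches):
            output.extend(out_batch)
    return output


def all_batches(group: Iterator[Batch]) -> Iterator[Batch]:
    batches = list(group)
    assert len(batches) == 2  # 4 rows per group, at most 2 rows per batch
    yield from batches


def first_batch_only(group: Iterator[Batch]) -> Iterator[Batch]:
    first = next(group)  # the remaining batches of the group are never pulled
    yield [r % 4 for r in first]


# Mirrors test_apply_in_arrow_batching: 12 rows, grouped by id // 4.
print(simulate_apply_in_arrow_iter(range(12), lambda r: r // 4, 2, all_batches))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# Mirrors test_apply_in_arrow_partial_iteration: 20 rows, grouped by id % 4.
print(simulate_apply_in_arrow_iter(range(20), lambda r: r % 4, 2, first_batch_only))
# [0, 0, 1, 1, 2, 2, 3, 3]
```

`first_batch_only` calls `next(group)` only once, so the remaining batches of each group are never read. Each group of five rows yields a first batch of two rows, which is why the partial-iteration test expects exactly two rows per group.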
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py
index e61fdb81c028..1b6c5135f701 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_typehints.py
@@ -28,9 +28,10 @@ from pyspark.testing.utils import (
     numpy_requirement_message,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
-from pyspark.sql.pandas.typehints import infer_eval_type
+from pyspark.sql.pandas.typehints import infer_eval_type, infer_group_arrow_eval_type
 from pyspark.sql.pandas.functions import arrow_udf, ArrowUDFType
 from pyspark.sql import Row
+from pyspark.util import PythonEvalType
 
 if have_pyarrow:
     import pyarrow as pa
@@ -172,6 +173,28 @@ class ArrowUDFTypeHintsTests(ReusedSQLTestCase):
             ArrowUDFType.GROUPED_AGG,
         )
 
+    def test_type_annotation_group_map(self):
+        def func(col: pa.Table) -> pa.Table:
+            pass
+
+        self.assertEqual(
+            infer_group_arrow_eval_type(signature(func), get_type_hints(func)),
+            PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        )
+
+        def func(col: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            pass
+
+        self.assertEqual(
+            infer_group_arrow_eval_type(signature(func), get_type_hints(func)),
+            PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
+        )
+
+        def func(col: Iterator[pa.Array]) -> Iterator[pa.Array]:
+            pass
+
+        self.assertEqual(infer_group_arrow_eval_type(signature(func), get_type_hints(func)), None)
+
     def test_type_annotation_negative(self):
         def func(col: str) -> pa.Array:
             pass
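The body of `infer_group_arrow_eval_type` is not part of this excerpt. Below is a simplified, hypothetical sketch of the dispatch these tests describe:

- `Table -> Table` selects the existing grouped-map eval type.
- `Iterator[RecordBatch] -> Iterator[RecordBatch]` selects the new iterator eval type.
- Anything else yields `None`.

The `Table` and `RecordBatch` classes are stand-ins so the sketch runs without pyarrow. The sketch inspects only the last parameter, and the real function may handle key arguments and other hints differently.

```python
import collections.abc
import inspect
from typing import Callable, Iterator, Optional, get_args, get_origin, get_type_hints


class Table:  # stand-in for pyarrow.Table (assumption)
    pass


class RecordBatch:  # stand-in for pyarrow.RecordBatch (assumption)
    pass


def _is_iterator_of(hint, item_type) -> bool:
    # typing.Iterator[X] reports collections.abc.Iterator as its origin.
    return get_origin(hint) is collections.abc.Iterator and get_args(hint) == (item_type,)


def infer_group_eval_type_sketch(func: Callable) -> Optional[str]:
    """Hypothetical: pick an eval type name from the last parameter's hint and the return hint."""
    hints = get_type_hints(func)
    params = list(inspect.signature(func).parameters)
    if not params or params[-1] not in hints or "return" not in hints:
        return None
    arg_hint, ret_hint = hints[params[-1]], hints["return"]

    if arg_hint is Table and ret_hint is Table:
        return "SQL_GROUPED_MAP_ARROW_UDF"
    if _is_iterator_of(arg_hint, RecordBatch) and _is_iterator_of(ret_hint, RecordBatch):
        return "SQL_GROUPED_MAP_ARROW_ITER_UDF"
    return None


def f_table(t: Table) -> Table: ...
def f_iter(it: Iterator[RecordBatch]) -> Iterator[RecordBatch]: ...
def f_other(it: Iterator[int]) -> Iterator[int]: ...


print(infer_group_eval_type_sketch(f_table))  # SQL_GROUPED_MAP_ARROW_UDF
print(infer_group_eval_type_sketch(f_iter))   # SQL_GROUPED_MAP_ARROW_ITER_UDF
print(infer_group_eval_type_sketch(f_other))  # None
```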
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 8497cabde5e6..f94fc73b6435 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -60,6 +60,7 @@ if typing.TYPE_CHECKING:
         ArrowMapIterUDFType,
         PandasGroupedMapUDFWithStateType,
         ArrowGroupedMapUDFType,
+        ArrowGroupedMapIterUDFType,
         ArrowCogroupedMapUDFType,
         PandasGroupedMapUDFTransformWithStateType,
         PandasGroupedMapUDFTransformWithStateInitStateType,
@@ -650,6 +651,7 @@ class PythonEvalType:
     SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF: "GroupedMapUDFTransformWithStateInitStateType" = (  # noqa: E501
         214
     )
+    SQL_GROUPED_MAP_ARROW_ITER_UDF: "ArrowGroupedMapIterUDFType" = 215
 
     # Arrow UDFs
     SQL_SCALAR_ARROW_UDF: "ArrowScalarUDFType" = 250
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 25eaf2624391..a15d59f04e1e 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -18,6 +18,7 @@
 """
 Worker that receives input from Piped RDD.
 """
+import itertools
 import os
 import sys
 import dataclasses
@@ -59,6 +60,7 @@ from pyspark.sql.pandas.serializers import (
     CogroupPandasUDFSerializer,
     ArrowStreamUDFSerializer,
     ApplyInPandasWithStateSerializer,
+    GroupPandasUDFSerializer,
     TransformWithStateInPandasSerializer,
     TransformWithStateInPandasInitStateSerializer,
     TransformWithStateInPySparkRowSerializer,
@@ -539,7 +541,7 @@ def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf):
             key = tuple(c[0] for c in key_table.columns)
             result = f(key, left_value_table, right_value_table)
 
-        verify_arrow_result(result, _assign_cols_by_name, expected_cols_and_types)
+        verify_arrow_table(result, _assign_cols_by_name, expected_cols_and_types)
 
         return result.to_batches()
 
@@ -572,25 +574,14 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
     return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), arrow_return_type)]
 
 
-def verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types):
-    import pyarrow as pa
-
-    if not isinstance(table, pa.Table):
-        raise PySparkTypeError(
-            errorClass="UDF_RETURN_TYPE",
-            messageParameters={
-                "expected": "pyarrow.Table",
-                "actual": type(table).__name__,
-            },
-        )
-
+def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
     # the types of the fields have to be identical to return type
     # an empty table can have no columns; if there are columns, they have to match
-    if table.num_columns != 0 or table.num_rows != 0:
+    if result.num_columns != 0 or result.num_rows != 0:
         # columns are either mapped by name or position
         if assign_cols_by_name:
             actual_cols_and_types = {
-                name: dataType for name, dataType in zip(table.schema.names, table.schema.types)
+                name: dataType for name, dataType in zip(result.schema.names, result.schema.types)
             }
             missing = sorted(
                 list(set(expected_cols_and_types.keys()).difference(actual_cols_and_types.keys()))
@@ -617,7 +608,7 @@ def verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types):
             ]
         else:
             actual_cols_and_types = [
-                (name, dataType) for name, dataType in zip(table.schema.names, table.schema.types)
+                (name, dataType) for name, dataType in zip(result.schema.names, result.schema.types)
             ]
             column_types = [
                 (expected_name, expected_type, actual_type)
@@ -644,7 +635,39 @@ def verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types):
             )
 
 
+def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types):
+    import pyarrow as pa
+
+    if not isinstance(table, pa.Table):
+        raise PySparkTypeError(
+            errorClass="UDF_RETURN_TYPE",
+            messageParameters={
+                "expected": "pyarrow.Table",
+                "actual": type(table).__name__,
+            },
+        )
+
+    verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types)
+
+
+def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
+    import pyarrow as pa
+
+    if not isinstance(batch, pa.RecordBatch):
+        raise PySparkTypeError(
+            errorClass="UDF_RETURN_TYPE",
+            messageParameters={
+                "expected": "pyarrow.RecordBatch",
+                "actual": type(batch).__name__,
+            },
+        )
+
+    verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)
+
+
 def wrap_grouped_map_arrow_udf(f, return_type, argspec, runner_conf):
+    import pyarrow as pa
+
     _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
     if _assign_cols_by_name:
@@ -656,16 +679,46 @@ def wrap_grouped_map_arrow_udf(f, return_type, argspec, runner_conf):
             (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
         ]
 
-    def wrapped(key_table, value_table):
+    def wrapped(key_batch, value_batches):
+        value_table = pa.Table.from_batches(value_batches)
         if len(argspec.args) == 1:
             result = f(value_table)
         elif len(argspec.args) == 2:
-            key = tuple(c[0] for c in key_table.columns)
+            key = tuple(c[0] for c in key_batch.columns)
             result = f(key, value_table)
 
-        verify_arrow_result(result, _assign_cols_by_name, expected_cols_and_types)
+        verify_arrow_table(result, _assign_cols_by_name, expected_cols_and_types)
 
-        return result.to_batches()
+        yield from result.to_batches()
+
+    arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf))
+    return lambda k, v: (wrapped(k, v), arrow_return_type)
+
+
+def wrap_grouped_map_arrow_iter_udf(f, return_type, argspec, runner_conf):
+    _assign_cols_by_name = assign_cols_by_name(runner_conf)
+
+    if _assign_cols_by_name:
+        expected_cols_and_types = {
+            col.name: to_arrow_type(col.dataType) for col in return_type.fields
+        }
+    else:
+        expected_cols_and_types = [
+            (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
+        ]
+
+    def wrapped(key_batch, value_batches):
+        if len(argspec.args) == 1:
+            result = f(value_batches)
+        elif len(argspec.args) == 2:
+            key = tuple(c[0] for c in key_batch.columns)
+            result = f(key, value_batches)
+
+        def verify_element(batch):
+            verify_arrow_batch(batch, _assign_cols_by_name, expected_cols_and_types)
+            return batch
+
+        yield from map(verify_element, result)
 
     arrow_return_type = to_arrow_type(return_type, use_large_var_types(runner_conf))
     return lambda k, v: (wrapped(k, v), arrow_return_type)
@@ -1210,6 +1263,11 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
         return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, argspec, runner_conf)
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF:
+        argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
+        return args_offsets, wrap_grouped_map_arrow_iter_udf(
+            func, return_type, argspec, runner_conf
+        )
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         return args_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
@@ -2505,6 +2563,7 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
         PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
@@ -2547,7 +2606,10 @@ def read_udfs(pickleSer, infile, eval_type):
         )
         _assign_cols_by_name = assign_cols_by_name(runner_conf)
 
-        if eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+        if (
+            eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+            or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
+        ):
             ser = GroupArrowUDFSerializer(_assign_cols_by_name)
         elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
             ser = GroupPandasUDFSerializer(
@@ -2931,7 +2993,10 @@ def read_udfs(pickleSer, infile, eval_type):
                 # mode == PROCESS_TIMER or mode == COMPLETE
                 return f(stateful_processor_api_client, mode, None, iter([]))
 
-    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+    elif (
+        eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+        or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
+    ):
         import pyarrow as pa
 
         # We assume there is only one UDF here because grouped map doesn't
@@ -2951,13 +3016,18 @@ def read_udfs(pickleSer, infile, eval_type):
                 names=[batch.schema.names[o] for o in offsets],
             )
 
-        def table_from_batches(batches, offsets):
-            return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches])
-
         def mapper(a):
-            keys = table_from_batches(a, parsed_offsets[0][0])
-            vals = table_from_batches(a, parsed_offsets[0][1])
-            return f(keys, vals)
+            batch_iter = iter(a)
+            # Need to materialize the first batch to get the keys
+            first_batch = next(batch_iter)
+
+            keys = batch_from_offset(first_batch, parsed_offsets[0][0])
+            value_batches = (
+                batch_from_offset(b, parsed_offsets[0][1])
+                for b in itertools.chain((first_batch,), batch_iter)
+            )
+
+            return f(keys, value_batches)
 
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         # We assume there is only one UDF here because grouped map doesn't
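The new `mapper` works in three steps:

1. It materializes only the first batch of a group and reads the grouping key from it.
2. It chains that batch back in front of the remaining ones, so the UDF still sees every value batch.
3. `wrap_grouped_map_arrow_iter_udf` then checks each returned batch lazily through `map`.

The pure-Python sketch below follows that flow under a few assumptions. Rows are modeled as dicts and batches as lists of dicts. `mapper_sketch`, `verify_lazily`, and `Batch` are hypothetical names, and the column-name check is a simplified substitute for `verify_arrow_batch`.

```python
import itertools
from typing import Callable, Dict, Iterable, Iterator, List, Tuple

Batch = List[Dict[str, int]]  # hypothetical stand-in for pyarrow.RecordBatch
GroupUDF = Callable[[Tuple[int, ...], Iterator[Batch]], Iterator[Batch]]


def mapper_sketch(
    batches: Iterable[Batch], key_cols: List[str], value_cols: List[str], udf: GroupUDF
) -> Iterator[Batch]:
    batch_iter = iter(batches)
    # Materialize only the first batch: all batches of a group share the same key.
    first_batch = next(batch_iter)
    key = tuple(first_batch[0][c] for c in key_cols)
    # Chain the first batch back in front of the rest and project value columns lazily.
    value_batches = (
        [{c: row[c] for c in value_cols} for row in b]
        for b in itertools.chain((first_batch,), batch_iter)
    )
    return udf(key, value_batches)


def verify_lazily(result: Iterable[Batch], expected_cols: List[str]) -> Iterator[Batch]:
    # Simplified substitute for verify_arrow_batch: only column names are checked.
    def verify_element(batch: Batch) -> Batch:
        for row in batch:
            if list(row) != expected_cols:
                raise TypeError(f"expected columns {expected_cols}, got {list(row)}")
        return batch

    # map() is lazy, so each batch is checked only when the consumer pulls it.
    yield from map(verify_element, result)


group = [[{"id": 1, "v": 10}, {"id": 1, "v": 20}], [{"id": 1, "v": 30}]]


def sum_v(key: Tuple[int, ...], batches: Iterator[Batch]) -> Iterator[Batch]:
    total = sum(row["v"] for b in batches for row in b)
    yield [{"id": key[0], "total": total}]


print(list(verify_lazily(mapper_sketch(group, ["id"], ["v"], sum_v), ["id", "total"])))
# [[{'id': 1, 'total': 60}]]
```

Because the check runs inside a lazy `map`, a malformed batch is reported only when it is consumed. Earlier batches from the same group may already have been yielded by then.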
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 3a1b3cebe893..37c6d529b82f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -652,7 +652,8 @@ class SparkConnectPlanner(
           case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF =>
             group.flatMapGroupsInPandas(Column(pythonUdf)).logicalPlan
 
-          case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF =>
+          case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF |
+              PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF =>
             group.flatMapGroupsInArrow(Column(pythonUdf)).logicalPlan
 
           case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF |
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
index 0b1da71be9ed..674c206c96a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/RelationalGroupedDataset.scala
@@ -311,7 +311,8 @@ class RelationalGroupedDataset protected[sql](
    */
   private[sql] def flatMapGroupsInArrow(column: Column): DataFrame = {
     val expr = column.expr.asInstanceOf[PythonUDF]
-    require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+    require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF ||
+      expr.evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
       "Must pass a grouped map arrow udf")
     require(expr.dataType.isInstanceOf[StructType],
       s"The returnType of the udf must be a ${StructType.simpleString}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
index 6569b29f3954..2885b127cc82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
@@ -27,19 +26,16 @@ import org.apache.spark.sql.types.{StructField, StructType}
 /**
  * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInArrow]]
  *
- * Rows in each group are passed to the Python worker as an Arrow record batch.
- * The Python worker turns the record batch to a `pyarrow.Table`, invokes the
- * user-defined function, and passes the resulting `pyarrow.Table`
- * as an Arrow record batch. Finally, each record batch is turned to
+ * Rows in each group are passed to the Python worker as an iterator of Arrow record batches.
+ * The Python worker passes the record batches either as a materialized `pyarrow.Table` or
+ * an iterator of pyarrow.RecordBatch, depending on the eval type of the user-defined function.
+ * The Python worker returns the resulting record batches which are turned into an
  * Iterator[InternalRow] using ColumnarBatch.
  *
  * Note on memory usage:
- * Both the Python worker and the Java executor need to have enough memory to
- * hold the largest group. The memory on the Java side is used to construct the
- * record batch (off heap memory). The memory on the Python side is used for
- * holding the `pyarrow.Table`. It's possible to further split one group into
- * multiple record batches to reduce the memory footprint on the Java side, this
- * is left as future work.
+ * When using the `pyarrow.Table` API, the entire group is materialized in memory in the Python
+ * worker, and the entire result for a group must also be fully materialized. The iterator of
+ * record batches API can be used to avoid this limitation on the Python side.
  */
 case class FlatMapGroupsInArrowExec(
     groupingAttributes: Seq[Attribute],
@@ -48,7 +44,9 @@ case class FlatMapGroupsInArrowExec(
     child: SparkPlan)
   extends FlatMapGroupsInBatchExec {
 
-  protected val pythonEvalType: Int = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+  protected val pythonEvalType: Int = {
+    func.asInstanceOf[PythonUDF].evalType
+  }
 
   override protected def groupedData(iter: Iterator[InternalRow], attrs: Seq[Attribute]):
       Iterator[Iterator[InternalRow]] =
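The updated scaladoc says the `pyarrow.Table` API materializes the whole group, and its whole result, in the Python worker. The iterator-of-batches API avoids this. The pure-Python sketch below illustrates only that Python-side difference. A hypothetical `CountingSource` stands in for one group's incoming record batches and counts how many have been pulled when each style of UDF produces its first output. It does not model Arrow buffers or JVM memory.

```python
from typing import Iterator, List


class CountingSource:
    """Hypothetical stand-in for one group's record batches; records how many were pulled."""

    def __init__(self, n_batches: int, batch_size: int) -> None:
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.pulled = 0

    def __iter__(self) -> Iterator[List[int]]:
        for i in range(self.n_batches):
            self.pulled += 1
            start = i * self.batch_size
            yield list(range(start, start + self.batch_size))


def streaming_udf(batches: Iterator[List[int]]) -> Iterator[List[int]]:
    # Iterator style: holds one input batch at a time and emits one output per input.
    for batch in batches:
        yield [x * 2 for x in batch]


def materializing_udf(batches: Iterator[List[int]]) -> Iterator[List[int]]:
    # Table style: concatenates the whole group before producing any output.
    table = [x for batch in batches for x in batch]
    yield [x * 2 for x in table]


src = CountingSource(n_batches=5, batch_size=3)
first = next(streaming_udf(iter(src)))
print(first, src.pulled)  # [0, 2, 4] 1

src = CountingSource(n_batches=5, batch_size=3)
first = next(materializing_udf(iter(src)))
print(len(first), src.pulled)  # 15 5
```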

