This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c0fc708b88f [SPARK-51891][SS] Squeeze the protocol of ListState GET / PUT / APPENDLIST for transformWithState in PySpark
2c0fc708b88f is described below

commit 2c0fc708b88fef383b910c133aec97e1ec8a6a38
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Fri Apr 25 10:08:27 2025 +0900

    [SPARK-51891][SS] Squeeze the protocol of ListState GET / PUT / APPENDLIST for transformWithState in PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to squeeze the protocol of ListState GET / PUT / APPENDLIST for transformWithState in PySpark, which helps significantly when dealing with small lists in ListState.
    
    Here are the changes:
    
    * ListState.get() no longer requires an additional request to learn that there is no further data to read.
      * We inline the data into the proto message, which makes it easy to determine whether the iterator has been fully consumed.
    * ListState.put() / ListState.appendList() no longer require an additional request to send the data separately.
      * We inline the data into the proto message if the list we pass is small enough (currently "magically" set to 100 elements - this needs further investigation).
      * If the length of the list exceeds 100, we fall back to the "old" Arrow send (rather than the custom protocol). This is because a pickled Python Row contains the schema information as a string, which is larger than we anticipated, so beyond some size Arrow becomes more efficient.
    
    NOTE: 100 is a "magic number" of sorts, and we will need to refine it with more benchmarking. A simplified sketch of the client-side behavior is shown below.
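    
    For illustration only - a minimal, self-contained sketch of the two client-side behaviors described above. All names here (INLINE_THRESHOLD, serialize_row, build_put_request, drain_list_get) are illustrative stand-ins, not the actual API; the real implementation lives in python/pyspark/sql/streaming/list_state_client.py and uses the generated proto messages (ListStatePut / AppendList / StateResponseWithListGet) shown in the diff below.
    
    ```python
    import pickle
    from typing import Iterator, List, Tuple

    INLINE_THRESHOLD = 100  # current "magic number"; to be tuned under SPARK-51907


    def serialize_row(row: Tuple) -> bytes:
        # Stand-in for pickling a Row against the state schema.
        return pickle.dumps(row)


    def build_put_request(values: List[Tuple]) -> dict:
        """PUT / APPENDLIST: inline pickled rows for small lists; otherwise
        signal the server to read a follow-up Arrow batch instead."""
        if len(values) >= INLINE_THRESHOLD:
            return {"value": [], "fetchWithArrow": True}
        return {"value": [serialize_row(v) for v in values], "fetchWithArrow": False}


    def drain_list_get(responses: Iterator[dict]) -> List[Tuple]:
        """GET: each response inlines a batch of rows plus a requireNextFetch
        flag, so the client knows it is done without an extra round trip."""
        rows: List[Tuple] = []
        for response in responses:
            rows.extend(pickle.loads(b) for b in response["value"])
            if not response["requireNextFetch"]:
                break
        return rows


    if __name__ == "__main__":
        assert build_put_request([(i,) for i in range(10)])["fetchWithArrow"] is False
        assert build_put_request([(i,) for i in range(150)])["fetchWithArrow"] is True

        batches = [
            {"value": [serialize_row((0,)), serialize_row((1,))], "requireNextFetch": True},
            {"value": [serialize_row((2,))], "requireNextFetch": False},
        ]
        assert drain_list_get(iter(batches)) == [(0,), (1,), (2,)]
    ```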
    
    ### Why are the changes needed?
    
    To further optimize ListState operations.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New UT.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50689 from HeartSaVioR/SPARK-51891.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/streaming/list_state_client.py  |  81 +++++++++--
 .../sql/streaming/proto/StateMessage_pb2.py        | 156 +++++++++++----------
 .../sql/streaming/proto/StateMessage_pb2.pyi       |  72 ++++++++++
 .../sql/streaming/stateful_processor_api_client.py |  12 ++
 .../helper/helper_pandas_transform_with_state.py   | 131 +++++++++++++++++
 .../pandas/test_pandas_transform_with_state.py     |  74 ++++++++++
 .../sql/execution/streaming/StateMessage.proto     |  11 ++
 .../TransformWithStateInPySparkStateServer.scala   |  69 ++++++++-
 ...ansformWithStateInPySparkStateServerSuite.scala |  62 +++++---
 9 files changed, 553 insertions(+), 115 deletions(-)

diff --git a/python/pyspark/sql/streaming/list_state_client.py 
b/python/pyspark/sql/streaming/list_state_client.py
index 66f2640c935e..08b672e86e08 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -37,7 +37,7 @@ class ListStateClient:
             self.schema = schema
         # A dictionary to store the mapping between list state name and a 
tuple of data batch
         # and the index of the last row that was read.
-        self.data_batch_dict: Dict[str, Tuple[Any, int]] = {}
+        self.data_batch_dict: Dict[str, Tuple[Any, int, bool]] = {}
 
     def exists(self, state_name: str) -> bool:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -61,12 +61,12 @@ class ListStateClient:
                 f"Error checking value state exists: " f"{response_message[1]}"
             )
 
-    def get(self, state_name: str, iterator_id: str) -> Tuple:
+    def get(self, state_name: str, iterator_id: str) -> Tuple[Tuple, bool]:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if iterator_id in self.data_batch_dict:
             # If the state is already in the dictionary, return the next row.
-            data_batch, index = self.data_batch_dict[iterator_id]
+            data_batch, index, require_next_fetch = 
self.data_batch_dict[iterator_id]
         else:
             # If the state is not in the dictionary, fetch the state from the 
server.
             get_call = stateMessage.ListStateGet(iteratorId=iterator_id)
@@ -79,23 +79,35 @@ class ListStateClient:
             message = 
stateMessage.StateRequest(stateVariableRequest=state_variable_request)
 
             
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
-            response_message = 
self._stateful_processor_api_client._receive_proto_message()
+            response_message = (
+                
self._stateful_processor_api_client._receive_proto_message_with_list_get()
+            )
             status = response_message[0]
             if status == 0:
-                data_batch = 
self._stateful_processor_api_client._read_list_state()
+                data_batch = list(
+                    map(
+                        lambda x: 
self._stateful_processor_api_client._deserialize_from_bytes(x),
+                        response_message[2],
+                    )
+                )
+                require_next_fetch = response_message[3]
                 index = 0
             else:
                 raise StopIteration()
 
+        is_last_row = False
         new_index = index + 1
         if new_index < len(data_batch):
             # Update the index in the dictionary.
-            self.data_batch_dict[iterator_id] = (data_batch, new_index)
+            self.data_batch_dict[iterator_id] = (data_batch, new_index, 
require_next_fetch)
         else:
             # If the index is at the end of the data batch, remove the state 
from the dictionary.
             self.data_batch_dict.pop(iterator_id, None)
+            is_last_row = True
+
+        is_last_row_from_iterator = is_last_row and not require_next_fetch
         row = data_batch[index]
-        return tuple(row)
+        return (tuple(row), is_last_row_from_iterator)
 
     def append_value(self, state_name: str, value: Tuple) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -118,7 +130,24 @@ class ListStateClient:
     def append_list(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        append_list_call = stateMessage.AppendList()
+        send_data_via_arrow = False
+
+        # To workaround mypy type assignment check.
+        values_as_bytes: Any = []
+        if len(values) == 100:
+            # TODO(SPARK-51907): Let's update this to be either flexible or 
more reasonable default
+            #  value backed by various benchmarks.
+            # Arrow codepath
+            send_data_via_arrow = True
+        else:
+            values_as_bytes = map(
+                lambda x: 
self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
+                values,
+            )
+
+        append_list_call = stateMessage.AppendList(
+            value=values_as_bytes, fetchWithArrow=send_data_via_arrow
+        )
         list_state_call = stateMessage.ListStateCall(
             stateName=state_name, appendList=append_list_call
         )
@@ -127,7 +156,9 @@ class ListStateClient:
 
         
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        self._stateful_processor_api_client._send_list_state(self.schema, 
values)
+        if send_data_via_arrow:
+            self._stateful_processor_api_client._send_arrow_state(self.schema, 
values)
+
         response_message = 
self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
@@ -137,14 +168,32 @@ class ListStateClient:
     def put(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        put_call = stateMessage.ListStatePut()
+        send_data_via_arrow = False
+        # To workaround mypy type assignment check.
+        values_as_bytes: Any = []
+        if len(values) == 100:
+            # TODO(SPARK-51907): Let's update this to be either flexible or 
more reasonable default
+            #  value backed by various benchmarks.
+            send_data_via_arrow = True
+        else:
+            values_as_bytes = map(
+                lambda x: 
self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
+                values,
+            )
+
+        put_call = stateMessage.ListStatePut(
+            value=values_as_bytes, fetchWithArrow=send_data_via_arrow
+        )
+
         list_state_call = stateMessage.ListStateCall(stateName=state_name, 
listStatePut=put_call)
         state_variable_request = 
stateMessage.StateVariableRequest(listStateCall=list_state_call)
         message = 
stateMessage.StateRequest(stateVariableRequest=state_variable_request)
 
         
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        self._stateful_processor_api_client._send_list_state(self.schema, 
values)
+        if send_data_via_arrow:
+            self._stateful_processor_api_client._send_arrow_state(self.schema, 
values)
+
         response_message = 
self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
@@ -174,9 +223,17 @@ class ListStateIterator:
         # Generate a unique identifier for the iterator to make sure iterators 
from the same
         # list state do not interfere with each other.
         self.iterator_id = str(uuid.uuid4())
+        self.iterator_fully_consumed = False
 
     def __iter__(self) -> Iterator[Tuple]:
         return self
 
     def __next__(self) -> Tuple:
-        return self.list_state_client.get(self.state_name, self.iterator_id)
+        if self.iterator_fully_consumed:
+            raise StopIteration()
+
+        row, is_last_row = self.list_state_client.get(self.state_name, 
self.iterator_id)
+        if is_last_row:
+            self.iterator_fully_consumed = True
+
+        return row
diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py 
b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
index 20af541f307c..094f1dd51c58 100644
--- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
@@ -40,7 +40,7 @@ _sym_db = _symbol_database.Default()
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\x84\x05\n\x0cStateRequest\x12\x18\n\x07version\x18\x01
 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 
\x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03
 
\x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14st
 [...]
+    
b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\x84\x05\n\x0cStateRequest\x12\x18\n\x07version\x18\x01
 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 
\x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03
 
\x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14st
 [...]
 )
 
 _globals = globals()
@@ -50,8 +50,8 @@ _builder.BuildTopDescriptorsAndMessages(
 )
 if not _descriptor._USE_C_DESCRIPTORS:
     DESCRIPTOR._loaded_options = None
-    _globals["_HANDLESTATE"]._serialized_start = 6408
-    _globals["_HANDLESTATE"]._serialized_end = 6518
+    _globals["_HANDLESTATE"]._serialized_start = 6695
+    _globals["_HANDLESTATE"]._serialized_end = 6805
     _globals["_STATEREQUEST"]._serialized_start = 112
     _globals["_STATEREQUEST"]._serialized_end = 756
     _globals["_STATERESPONSE"]._serialized_start = 758
@@ -60,78 +60,80 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 985
     _globals["_STATERESPONSEWITHSTRINGTYPEVAL"]._serialized_start = 987
     _globals["_STATERESPONSEWITHSTRINGTYPEVAL"]._serialized_end = 1109
-    _globals["_STATEFULPROCESSORCALL"]._serialized_start = 1112
-    _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1784
-    _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1787
-    _globals["_STATEVARIABLEREQUEST"]._serialized_end = 2128
-    _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 2131
-    _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2390
-    _globals["_TIMERREQUEST"]._serialized_start = 2393
-    _globals["_TIMERREQUEST"]._serialized_end = 2650
-    _globals["_TIMERVALUEREQUEST"]._serialized_start = 2653
-    _globals["_TIMERVALUEREQUEST"]._serialized_end = 2899
-    _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2901
-    _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2967
-    _globals["_GETPROCESSINGTIME"]._serialized_start = 2969
-    _globals["_GETPROCESSINGTIME"]._serialized_end = 2988
-    _globals["_GETWATERMARK"]._serialized_start = 2990
-    _globals["_GETWATERMARK"]._serialized_end = 3004
-    _globals["_UTILSREQUEST"]._serialized_start = 3007
-    _globals["_UTILSREQUEST"]._serialized_end = 3146
-    _globals["_PARSESTRINGSCHEMA"]._serialized_start = 3148
-    _globals["_PARSESTRINGSCHEMA"]._serialized_end = 3191
-    _globals["_STATECALLCOMMAND"]._serialized_start = 3194
-    _globals["_STATECALLCOMMAND"]._serialized_end = 3393
-    _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 3396
-    _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3691
-    _globals["_VALUESTATECALL"]._serialized_start = 3694
-    _globals["_VALUESTATECALL"]._serialized_end = 4096
-    _globals["_LISTSTATECALL"]._serialized_start = 4099
-    _globals["_LISTSTATECALL"]._serialized_end = 4706
-    _globals["_MAPSTATECALL"]._serialized_start = 4709
-    _globals["_MAPSTATECALL"]._serialized_end = 5543
-    _globals["_SETIMPLICITKEY"]._serialized_start = 5545
-    _globals["_SETIMPLICITKEY"]._serialized_end = 5579
-    _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5581
-    _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5600
-    _globals["_EXISTS"]._serialized_start = 5602
-    _globals["_EXISTS"]._serialized_end = 5610
-    _globals["_GET"]._serialized_start = 5612
-    _globals["_GET"]._serialized_end = 5617
-    _globals["_REGISTERTIMER"]._serialized_start = 5619
-    _globals["_REGISTERTIMER"]._serialized_end = 5680
-    _globals["_DELETETIMER"]._serialized_start = 5682
-    _globals["_DELETETIMER"]._serialized_end = 5741
-    _globals["_LISTTIMERS"]._serialized_start = 5743
-    _globals["_LISTTIMERS"]._serialized_end = 5787
-    _globals["_VALUESTATEUPDATE"]._serialized_start = 5789
-    _globals["_VALUESTATEUPDATE"]._serialized_end = 5829
-    _globals["_CLEAR"]._serialized_start = 5831
-    _globals["_CLEAR"]._serialized_end = 5838
-    _globals["_LISTSTATEGET"]._serialized_start = 5840
-    _globals["_LISTSTATEGET"]._serialized_end = 5886
-    _globals["_LISTSTATEPUT"]._serialized_start = 5888
-    _globals["_LISTSTATEPUT"]._serialized_end = 5902
-    _globals["_APPENDVALUE"]._serialized_start = 5904
-    _globals["_APPENDVALUE"]._serialized_end = 5939
-    _globals["_APPENDLIST"]._serialized_start = 5941
-    _globals["_APPENDLIST"]._serialized_end = 5953
-    _globals["_GETVALUE"]._serialized_start = 5955
-    _globals["_GETVALUE"]._serialized_end = 5991
-    _globals["_CONTAINSKEY"]._serialized_start = 5993
-    _globals["_CONTAINSKEY"]._serialized_end = 6032
-    _globals["_UPDATEVALUE"]._serialized_start = 6034
-    _globals["_UPDATEVALUE"]._serialized_end = 6095
-    _globals["_ITERATOR"]._serialized_start = 6097
-    _globals["_ITERATOR"]._serialized_end = 6139
-    _globals["_KEYS"]._serialized_start = 6141
-    _globals["_KEYS"]._serialized_end = 6179
-    _globals["_VALUES"]._serialized_start = 6181
-    _globals["_VALUES"]._serialized_end = 6221
-    _globals["_REMOVEKEY"]._serialized_start = 6223
-    _globals["_REMOVEKEY"]._serialized_end = 6260
-    _globals["_SETHANDLESTATE"]._serialized_start = 6262
-    _globals["_SETHANDLESTATE"]._serialized_end = 6361
-    _globals["_TTLCONFIG"]._serialized_start = 6363
-    _globals["_TTLCONFIG"]._serialized_end = 6406
+    _globals["_STATERESPONSEWITHLISTGET"]._serialized_start = 1112
+    _globals["_STATERESPONSEWITHLISTGET"]._serialized_end = 1272
+    _globals["_STATEFULPROCESSORCALL"]._serialized_start = 1275
+    _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1947
+    _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1950
+    _globals["_STATEVARIABLEREQUEST"]._serialized_end = 2291
+    _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 2294
+    _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2553
+    _globals["_TIMERREQUEST"]._serialized_start = 2556
+    _globals["_TIMERREQUEST"]._serialized_end = 2813
+    _globals["_TIMERVALUEREQUEST"]._serialized_start = 2816
+    _globals["_TIMERVALUEREQUEST"]._serialized_end = 3062
+    _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 3064
+    _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 3130
+    _globals["_GETPROCESSINGTIME"]._serialized_start = 3132
+    _globals["_GETPROCESSINGTIME"]._serialized_end = 3151
+    _globals["_GETWATERMARK"]._serialized_start = 3153
+    _globals["_GETWATERMARK"]._serialized_end = 3167
+    _globals["_UTILSREQUEST"]._serialized_start = 3170
+    _globals["_UTILSREQUEST"]._serialized_end = 3309
+    _globals["_PARSESTRINGSCHEMA"]._serialized_start = 3311
+    _globals["_PARSESTRINGSCHEMA"]._serialized_end = 3354
+    _globals["_STATECALLCOMMAND"]._serialized_start = 3357
+    _globals["_STATECALLCOMMAND"]._serialized_end = 3556
+    _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 3559
+    _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3854
+    _globals["_VALUESTATECALL"]._serialized_start = 3857
+    _globals["_VALUESTATECALL"]._serialized_end = 4259
+    _globals["_LISTSTATECALL"]._serialized_start = 4262
+    _globals["_LISTSTATECALL"]._serialized_end = 4869
+    _globals["_MAPSTATECALL"]._serialized_start = 4872
+    _globals["_MAPSTATECALL"]._serialized_end = 5706
+    _globals["_SETIMPLICITKEY"]._serialized_start = 5708
+    _globals["_SETIMPLICITKEY"]._serialized_end = 5742
+    _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5744
+    _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5763
+    _globals["_EXISTS"]._serialized_start = 5765
+    _globals["_EXISTS"]._serialized_end = 5773
+    _globals["_GET"]._serialized_start = 5775
+    _globals["_GET"]._serialized_end = 5780
+    _globals["_REGISTERTIMER"]._serialized_start = 5782
+    _globals["_REGISTERTIMER"]._serialized_end = 5843
+    _globals["_DELETETIMER"]._serialized_start = 5845
+    _globals["_DELETETIMER"]._serialized_end = 5904
+    _globals["_LISTTIMERS"]._serialized_start = 5906
+    _globals["_LISTTIMERS"]._serialized_end = 5950
+    _globals["_VALUESTATEUPDATE"]._serialized_start = 5952
+    _globals["_VALUESTATEUPDATE"]._serialized_end = 5992
+    _globals["_CLEAR"]._serialized_start = 5994
+    _globals["_CLEAR"]._serialized_end = 6001
+    _globals["_LISTSTATEGET"]._serialized_start = 6003
+    _globals["_LISTSTATEGET"]._serialized_end = 6049
+    _globals["_LISTSTATEPUT"]._serialized_start = 6051
+    _globals["_LISTSTATEPUT"]._serialized_end = 6127
+    _globals["_APPENDVALUE"]._serialized_start = 6129
+    _globals["_APPENDVALUE"]._serialized_end = 6164
+    _globals["_APPENDLIST"]._serialized_start = 6166
+    _globals["_APPENDLIST"]._serialized_end = 6240
+    _globals["_GETVALUE"]._serialized_start = 6242
+    _globals["_GETVALUE"]._serialized_end = 6278
+    _globals["_CONTAINSKEY"]._serialized_start = 6280
+    _globals["_CONTAINSKEY"]._serialized_end = 6319
+    _globals["_UPDATEVALUE"]._serialized_start = 6321
+    _globals["_UPDATEVALUE"]._serialized_end = 6382
+    _globals["_ITERATOR"]._serialized_start = 6384
+    _globals["_ITERATOR"]._serialized_end = 6426
+    _globals["_KEYS"]._serialized_start = 6428
+    _globals["_KEYS"]._serialized_end = 6466
+    _globals["_VALUES"]._serialized_start = 6468
+    _globals["_VALUES"]._serialized_end = 6508
+    _globals["_REMOVEKEY"]._serialized_start = 6510
+    _globals["_REMOVEKEY"]._serialized_end = 6547
+    _globals["_SETHANDLESTATE"]._serialized_start = 6549
+    _globals["_SETHANDLESTATE"]._serialized_end = 6648
+    _globals["_TTLCONFIG"]._serialized_start = 6650
+    _globals["_TTLCONFIG"]._serialized_end = 6693
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi 
b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
index ac4b03b82034..aa86826862bb 100644
--- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
@@ -34,7 +34,9 @@ See the License for the specific language governing 
permissions and
 limitations under the License.
 """
 import builtins
+import collections.abc
 import google.protobuf.descriptor
+import google.protobuf.internal.containers
 import google.protobuf.internal.enum_type_wrapper
 import google.protobuf.message
 import sys
@@ -229,6 +231,44 @@ class 
StateResponseWithStringTypeVal(google.protobuf.message.Message):
 
 global___StateResponseWithStringTypeVal = StateResponseWithStringTypeVal
 
+class StateResponseWithListGet(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    STATUSCODE_FIELD_NUMBER: builtins.int
+    ERRORMESSAGE_FIELD_NUMBER: builtins.int
+    VALUE_FIELD_NUMBER: builtins.int
+    REQUIRENEXTFETCH_FIELD_NUMBER: builtins.int
+    statusCode: builtins.int
+    errorMessage: builtins.str
+    @property
+    def value(
+        self,
+    ) -> 
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
 ...
+    requireNextFetch: builtins.bool
+    def __init__(
+        self,
+        *,
+        statusCode: builtins.int = ...,
+        errorMessage: builtins.str = ...,
+        value: collections.abc.Iterable[builtins.bytes] | None = ...,
+        requireNextFetch: builtins.bool = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "errorMessage",
+            b"errorMessage",
+            "requireNextFetch",
+            b"requireNextFetch",
+            "statusCode",
+            b"statusCode",
+            "value",
+            b"value",
+        ],
+    ) -> None: ...
+
+global___StateResponseWithListGet = StateResponseWithListGet
+
 class StatefulProcessorCall(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -1042,8 +1082,24 @@ global___ListStateGet = ListStateGet
 class ListStatePut(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    VALUE_FIELD_NUMBER: builtins.int
+    FETCHWITHARROW_FIELD_NUMBER: builtins.int
+    @property
+    def value(
+        self,
+    ) -> 
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
 ...
+    fetchWithArrow: builtins.bool
     def __init__(
         self,
+        *,
+        value: collections.abc.Iterable[builtins.bytes] | None = ...,
+        fetchWithArrow: builtins.bool = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "fetchWithArrow", b"fetchWithArrow", "value", b"value"
+        ],
     ) -> None: ...
 
 global___ListStatePut = ListStatePut
@@ -1065,8 +1121,24 @@ global___AppendValue = AppendValue
 class AppendList(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    VALUE_FIELD_NUMBER: builtins.int
+    FETCHWITHARROW_FIELD_NUMBER: builtins.int
+    @property
+    def value(
+        self,
+    ) -> 
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]:
 ...
+    fetchWithArrow: builtins.bool
     def __init__(
         self,
+        *,
+        value: collections.abc.Iterable[builtins.bytes] | None = ...,
+        fetchWithArrow: builtins.bool = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "fetchWithArrow", b"fetchWithArrow", "value", b"value"
+        ],
     ) -> None: ...
 
 global___AppendList = AppendList
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py 
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index e564d7186faa..18330c4096fa 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -425,6 +425,18 @@ class StatefulProcessorApiClient:
         message.ParseFromString(bytes)
         return message.statusCode, message.errorMessage, message.value
 
+    # The third return type is RepeatedScalarFieldContainer[bytes], which is 
protobuf's container
+    # type. We simplify it to Any here to avoid unnecessary complexity.
+    def _receive_proto_message_with_list_get(self) -> Tuple[int, str, Any, 
bool]:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        length = read_int(self.sockfile)
+        bytes = self.sockfile.read(length)
+        message = stateMessage.StateResponseWithListGet()
+        message.ParseFromString(bytes)
+
+        return message.statusCode, message.errorMessage, message.value, 
message.requireNextFetch
+
     def _receive_str(self) -> str:
         return self.utf8_deserializer.loads(self.sockfile)
 
diff --git 
a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
index 53f6d77567ae..cc9f29609a5b 100644
--- 
a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
+++ 
b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
@@ -139,6 +139,14 @@ class ListStateProcessorFactory(StatefulProcessorFactory):
         return RowListStateProcessor()
 
 
+class ListStateLargeListProcessorFactory(StatefulProcessorFactory):
+    def pandas(self):
+        return PandasListStateLargeListProcessor()
+
+    def row(self):
+        return RowListStateLargeListProcessor()
+
+
 class ListStateLargeTTLProcessorFactory(StatefulProcessorFactory):
     def pandas(self):
         return PandasListStateLargeTTLProcessor()
@@ -922,6 +930,129 @@ class RowListStateProcessor(StatefulProcessor):
         pass
 
 
+class PandasListStateLargeListProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        list_state_schema = StructType([StructField("value", IntegerType(), 
True)])
+        value_state_schema = StructType([StructField("size", IntegerType(), 
True)])
+        self.list_state = handle.getListState("listState", list_state_schema)
+        self.list_size_state = handle.getValueState("listSizeState", 
value_state_schema)
+
+    def handleInputRows(self, key, rows, timerValues) -> 
Iterator[pd.DataFrame]:
+        elements_iter = self.list_state.get()
+        elements = list(elements_iter)
+
+        # Use the magic number 100 to test with both inline proto case and 
Arrow case.
+        # TODO(SPARK-51907): Let's update this to be either flexible or more 
reasonable default
+        #  value backed by various benchmarks.
+        # Put 90 elements per batch:
+        # 1st batch: read 0 element, and write 90 elements, read back 90 
elements
+        #   (both use inline proto)
+        # 2nd batch: read 90 elements, and write 90 elements, read back 180 
elements
+        #   (read uses both inline proto and Arrow, write uses Arrow)
+
+        if len(elements) == 0:
+            # should be the first batch
+            assert self.list_size_state.get() is None
+            new_elements = [(i,) for i in range(90)]
+            if key == ("0",):
+                self.list_state.put(new_elements)
+            else:
+                self.list_state.appendList(new_elements)
+            self.list_size_state.update((len(new_elements),))
+        else:
+            # check the elements
+            list_size = self.list_size_state.get()
+            assert list_size is not None
+            list_size = list_size[0]
+            assert list_size == len(
+                elements
+            ), f"list_size ({list_size}) != len(elements) ({len(elements)})"
+
+            expected_elements_in_state = [(i,) for i in range(list_size)]
+            assert elements == expected_elements_in_state
+
+            if key == ("0",):
+                # Use the operation `put`
+                new_elements = [(i,) for i in range(list_size + 90)]
+                self.list_state.put(new_elements)
+                final_size = len(new_elements)
+                self.list_size_state.update((final_size,))
+            else:
+                # Use the operation `appendList`
+                new_elements = [(i,) for i in range(list_size, list_size + 90)]
+                self.list_state.appendList(new_elements)
+                final_size = len(new_elements) + list_size
+                self.list_size_state.update((final_size,))
+
+        prev_elements = ",".join(map(lambda x: str(x[0]), elements))
+        updated_elements = ",".join(map(lambda x: str(x[0]), 
self.list_state.get()))
+
+        yield pd.DataFrame(
+            {"id": key, "prevElements": prev_elements, "updatedElements": 
updated_elements}
+        )
+
+
+class RowListStateLargeListProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        list_state_schema = StructType([StructField("value", IntegerType(), 
True)])
+        value_state_schema = StructType([StructField("size", IntegerType(), 
True)])
+        self.list_state = handle.getListState("listState", list_state_schema)
+        self.list_size_state = handle.getValueState("listSizeState", 
value_state_schema)
+
+    def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+        elements_iter = self.list_state.get()
+
+        elements = list(elements_iter)
+
+        # Use the magic number 100 to test with both inline proto case and 
Arrow case.
+        # TODO(SPARK-51907): Let's update this to be either flexible or more 
reasonable default
+        #  value backed by various benchmarks.
+        # Put 90 elements per batch:
+        # 1st batch: read 0 element, and write 90 elements, read back 90 
elements
+        #   (both use inline proto)
+        # 2nd batch: read 90 elements, and write 90 elements, read back 180 
elements
+        #   (read uses both inline proto and Arrow, write uses Arrow)
+
+        if len(elements) == 0:
+            # should be the first batch
+            assert self.list_size_state.get() is None
+            new_elements = [(i,) for i in range(90)]
+            if key == ("0",):
+                self.list_state.put(new_elements)
+            else:
+                self.list_state.appendList(new_elements)
+            self.list_size_state.update((len(new_elements),))
+        else:
+            # check the elements
+            list_size = self.list_size_state.get()
+            assert list_size is not None
+            list_size = list_size[0]
+            assert list_size == len(
+                elements
+            ), f"list_size ({list_size}) != len(elements) ({len(elements)})"
+
+            expected_elements_in_state = [(i,) for i in range(list_size)]
+            assert elements == expected_elements_in_state
+
+            if key == ("0",):
+                # Use the operation `put`
+                new_elements = [(i,) for i in range(list_size + 90)]
+                self.list_state.put(new_elements)
+                final_size = len(new_elements)
+                self.list_size_state.update((final_size,))
+            else:
+                # Use the operation `appendList`
+                new_elements = [(i,) for i in range(list_size, list_size + 90)]
+                self.list_state.appendList(new_elements)
+                final_size = len(new_elements) + list_size
+                self.list_size_state.update((final_size,))
+
+        prev_elements = ",".join(map(lambda x: str(x[0]), elements))
+        updated_elements = ",".join(map(lambda x: str(x[0]), 
self.list_state.get()))
+
+        yield Row(id=key[0], prevElements=prev_elements, 
updatedElements=updated_elements)
+
+
 class PandasListStateLargeTTLProcessor(PandasListStateProcessor):
     def init(self, handle: StatefulProcessorHandle) -> None:
         state_schema = StructType([StructField("temperature", IntegerType(), 
True)])
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 7f2469cd6b93..e36ae3a86a28 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -57,6 +57,7 @@ from 
pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import (
     TTLStatefulProcessorFactory,
     InvalidSimpleStatefulProcessorFactory,
     ListStateProcessorFactory,
+    ListStateLargeListProcessorFactory,
     ListStateLargeTTLProcessorFactory,
     MapStateProcessorFactory,
     MapStateLargeTTLProcessorFactory,
@@ -302,6 +303,79 @@ class TransformWithStateTestsMixin:
             ListStateProcessorFactory(), check_results, True, "processingTime"
         )
 
+    def test_transform_with_state_list_state_large_list(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                expected_prev_elements = ""
+                expected_updated_elements = ",".join(map(lambda x: str(x), 
range(90)))
+            else:
+                # batch_id == 1:
+                expected_prev_elements = ",".join(map(lambda x: str(x), 
range(90)))
+                expected_updated_elements = ",".join(map(lambda x: str(x), 
range(180)))
+
+            assert set(batch_df.sort("id").collect()) == {
+                Row(
+                    id="0",
+                    prevElements=expected_prev_elements,
+                    updatedElements=expected_updated_elements,
+                ),
+                Row(
+                    id="1",
+                    prevElements=expected_prev_elements,
+                    updatedElements=expected_updated_elements,
+                ),
+            }
+
+        input_path = tempfile.mkdtemp()
+        checkpoint_path = tempfile.mkdtemp()
+
+        self._prepare_test_resource1(input_path)
+        time.sleep(2)
+        self._prepare_test_resource2(input_path)
+
+        df = self._build_test_df(input_path)
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("prevElements", StringType(), True),
+                StructField("updatedElements", StringType(), True),
+            ]
+        )
+
+        stateful_processor = 
self.get_processor(ListStateLargeListProcessorFactory())
+        if self.use_pandas():
+            tws_df = df.groupBy("id").transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="none",
+            )
+        else:
+            tws_df = df.groupBy("id").transformWithState(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="none",
+            )
+
+        q = (
+            tws_df.writeStream.queryName("this_query")
+            .option("checkpointLocation", checkpoint_path)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+        self.assertEqual(q.name, "this_query")
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
     # test list state with ttl has the same behavior as list state when state 
doesn't expire.
     def test_transform_with_state_list_state_large_ttl(self):
         def check_results(batch_df, batch_id):
diff --git 
a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
 
b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
index 1374bd100a2f..ce83c285410b 100644
--- 
a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
+++ 
b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
@@ -48,6 +48,13 @@ message StateResponseWithStringTypeVal {
   string value = 3;
 }
 
+message StateResponseWithListGet {
+  int32 statusCode = 1;
+  string errorMessage = 2;
+  repeated bytes value = 3;
+  bool requireNextFetch = 4;
+}
+
 message StatefulProcessorCall {
   oneof method {
     SetHandleState setHandleState = 1;
@@ -197,6 +204,8 @@ message ListStateGet {
 }
 
 message ListStatePut {
+  repeated bytes value = 1;
+  bool fetchWithArrow = 2;
 }
 
 message AppendValue {
@@ -204,6 +213,8 @@ message AppendValue {
 }
 
 message AppendList {
+  repeated bytes value = 1;
+  bool fetchWithArrow = 2;
 }
 
 message GetValue {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
index 541ccf14b06d..a2c4d130ef31 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
@@ -22,6 +22,7 @@ import java.nio.channels.{Channels, ServerSocketChannel}
 import java.time.Duration
 
 import scala.collection.mutable
+import scala.jdk.CollectionConverters._
 
 import com.google.protobuf.ByteString
 import org.apache.arrow.vector.VectorSchemaRoot
@@ -38,6 +39,7 @@ import 
org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleImplBase, 
StatefulProcessorHandleState, StateVariableType}
 import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, 
ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall, 
StateRequest, StateResponse, StateResponseWithLongTypeVal, 
StateResponseWithStringTypeVal, StateVariableRequest, TimerRequest, 
TimerStateCallCommand, TimerValueRequest, UtilsRequest, ValueStateCall}
+import 
org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponseWithListGet
 import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, 
ValueState}
 import org.apache.spark.sql.types.{BinaryType, LongType, StructField, 
StructType}
 import org.apache.spark.sql.util.ArrowUtils
@@ -481,7 +483,17 @@ class TransformWithStateInPySparkStateServer(
           sendResponse(2, s"state $stateName doesn't exist")
         }
       case ListStateCall.MethodCase.LISTSTATEPUT =>
-        val rows = deserializer.readListElements(inputStream, listStateInfo)
+        val rows = if (message.getListStatePut.getFetchWithArrow) {
+          deserializer.readArrowBatches(inputStream)
+        } else {
+          val elements = message.getListStatePut.getValueList.asScala
+          elements.map { e =>
+            PythonSQLUtils.toJVMRow(
+              e.toByteArray,
+              listStateInfo.schema,
+              listStateInfo.deserializer)
+          }
+        }
         listStateInfo.listState.put(rows.toArray)
         sendResponse(0)
       case ListStateCall.MethodCase.LISTSTATEGET =>
@@ -494,8 +506,7 @@ class TransformWithStateInPySparkStateServer(
         if (!iteratorOption.get.hasNext) {
           sendResponse(2, s"List state $stateName doesn't contain any value.")
         } else {
-          sendResponse(0)
-          sendIteratorForListState(iteratorOption.get)
+          sendResponseWithListGet(0, iter = iteratorOption.get)
         }
       case ListStateCall.MethodCase.APPENDVALUE =>
         val byteArray = message.getAppendValue.getValue.toByteArray
@@ -504,7 +515,17 @@ class TransformWithStateInPySparkStateServer(
         listStateInfo.listState.appendValue(newRow)
         sendResponse(0)
       case ListStateCall.MethodCase.APPENDLIST =>
-        val rows = deserializer.readListElements(inputStream, listStateInfo)
+        val rows = if (message.getAppendList.getFetchWithArrow) {
+          deserializer.readArrowBatches(inputStream)
+        } else {
+          val elements = message.getAppendList.getValueList.asScala
+          elements.map { e =>
+            PythonSQLUtils.toJVMRow(
+              e.toByteArray,
+              listStateInfo.schema,
+              listStateInfo.deserializer)
+          }
+        }
         listStateInfo.listState.appendList(rows.toArray)
         sendResponse(0)
       case ListStateCall.MethodCase.CLEAR =>
@@ -771,6 +792,46 @@ class TransformWithStateInPySparkStateServer(
       outputStream.write(responseMessageBytes)
     }
 
+    def sendResponseWithListGet(
+        status: Int,
+        errorMessage: String = null,
+        iter: Iterator[Row] = null): Unit = {
+      val responseMessageBuilder = StateResponseWithListGet.newBuilder()
+        .setStatusCode(status)
+      if (status != 0 && errorMessage != null) {
+        responseMessageBuilder.setErrorMessage(errorMessage)
+      }
+
+      if (status == 0) {
+        // Only write a single batch in each GET request. Stops writing row if 
rowCount reaches
+        // the arrowTransformWithStateInPySparkMaxRecordsPerBatch limit. This 
is to handle a case
+        // when there are multiple state variables, user tries to access a 
different state variable
+        // while the current state variable is not exhausted yet.
+        var rowCount = 0
+        while (iter.hasNext && rowCount < 
arrowTransformWithStateInPySparkMaxRecordsPerBatch) {
+          val data = iter.next()
+
+          // Serialize the value row as a byte array
+          val valueBytes = PythonSQLUtils.toPyRow(data)
+
+          responseMessageBuilder.addValue(ByteString.copyFrom(valueBytes))
+
+          rowCount += 1
+        }
+
+        assert(rowCount > 0, s"rowCount should be greater than 0 when status 
code is 0, " +
+          s"iter.hasNext ${iter.hasNext}")
+
+        responseMessageBuilder.setRequireNextFetch(iter.hasNext)
+      }
+
+      val responseMessage = responseMessageBuilder.build()
+      val responseMessageBytes = responseMessage.toByteArray
+      val byteLength = responseMessageBytes.length
+      outputStream.writeInt(byteLength)
+      outputStream.write(responseMessageBytes)
+    }
+
     def sendIteratorAsArrowBatches[T](
         iter: Iterator[T],
         outputSchema: StructType,
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
index 318fd6ce1c8f..fd750edf96ec 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
@@ -262,10 +262,10 @@ class TransformWithStateInPySparkStateServerSuite extends 
SparkFunSuite with Bef
       
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    // 1 for row, 1 for end of the data, 1 for proto response
-    verify(outputStream, times(3)).writeInt(any)
-    // 1 for sending an actual row, 1 for sending proto message
-    verify(outputStream, times(2)).write(any[Array[Byte]])
+    // 1 for proto response
+    verify(outputStream).writeInt(any)
+    // 1 for sending proto message
+    verify(outputStream).write(any[Array[Byte]])
   }
 
   test("list state get - iterator in map with multiple batches") {
@@ -282,20 +282,20 @@ class TransformWithStateInPySparkStateServerSuite extends 
SparkFunSuite with Bef
     // First call should send 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto 
response
-    verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
-    // maxRecordsPerBatch times for rows, 1 for sending proto message
-    verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
+    // 1 for proto response
+    verify(outputStream).writeInt(any)
+    // 1 for proto message
+    verify(outputStream).write(any[Array[Byte]])
     // Second call should send the remaining 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
     // Since Mockito's verify counts the total number of calls, the expected 
number of writeInt
     // and write should be accumulated from the prior count; the number of 
calls are the same
     // with prior one.
-    // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto 
response
-    verify(outputStream, times(maxRecordsPerBatch * 2 + 4)).writeInt(any)
-    // maxRecordsPerBatch times for rows, 1 for sending proto message
-    verify(outputStream, times(maxRecordsPerBatch * 2 + 
2)).write(any[Array[Byte]])
+    // 1 for proto response
+    verify(outputStream, times(2)).writeInt(any)
+    // 1 for sending proto message
+    verify(outputStream, times(2)).write(any[Array[Byte]])
   }
 
   test("list state get - iterator not in map") {
@@ -314,17 +314,26 @@ class TransformWithStateInPySparkStateServerSuite extends 
SparkFunSuite with Bef
 
     // Verify that only maxRecordsPerBatch (2) rows are written to the output 
stream while still
     // having 1 row left in the iterator.
-    // maxRecordsPerBatch (2) for rows, 1 for end of the data, 1 for proto 
response
-    verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
-    // 2 for rows, 1 for proto message
-    verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
+    // 1 for proto response
+    verify(outputStream, times(1)).writeInt(any)
+    // 1 for proto message
+    verify(outputStream, times(1)).write(any[Array[Byte]])
+  }
+
+  test("list state put - inlined data") {
+    val message = ListStateCall.newBuilder().setStateName(stateName)
+      
.setListStatePut(ListStatePut.newBuilder().setFetchWithArrow(false).build()).build()
+    stateServer.handleListStateRequest(message)
+    // Verify that the data is not read from Arrow stream. It is inlined.
+    verify(transformWithStateInPySparkDeserializer, 
times(0)).readArrowBatches(any)
+    verify(listState).put(any)
   }
 
-  test("list state put") {
+  test("list state put - data via Arrow batch") {
     val message = ListStateCall.newBuilder().setStateName(stateName)
-      .setListStatePut(ListStatePut.newBuilder().build()).build()
+      
.setListStatePut(ListStatePut.newBuilder().setFetchWithArrow(true).build()).build()
     stateServer.handleListStateRequest(message)
-    verify(transformWithStateInPySparkDeserializer).readListElements(any, any)
+    verify(transformWithStateInPySparkDeserializer).readArrowBatches(any)
     verify(listState).put(any)
   }
 
@@ -336,11 +345,20 @@ class TransformWithStateInPySparkStateServerSuite extends 
SparkFunSuite with Bef
     verify(listState).appendValue(any[Row])
   }
 
-  test("list state append list") {
+  test("list state append list - inlined data") {
+    val message = ListStateCall.newBuilder().setStateName(stateName)
+      
.setAppendList(AppendList.newBuilder().setFetchWithArrow(false).build()).build()
+    stateServer.handleListStateRequest(message)
+    // Verify that the data is not read from Arrow stream. It is inlined.
+    verify(transformWithStateInPySparkDeserializer, 
times(0)).readArrowBatches(any)
+    verify(listState).appendList(any)
+  }
+
+  test("list state append list - data via Arrow batch") {
     val message = ListStateCall.newBuilder().setStateName(stateName)
-      .setAppendList(AppendList.newBuilder().build()).build()
+      
.setAppendList(AppendList.newBuilder().setFetchWithArrow(true).build()).build()
     stateServer.handleListStateRequest(message)
-    verify(transformWithStateInPySparkDeserializer).readListElements(any, any)
+    verify(transformWithStateInPySparkDeserializer).readArrowBatches(any)
     verify(listState).appendList(any)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
