This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2c0fc708b88f [SPARK-51891][SS] Squeeze the protocol of ListState GET / PUT / APPENDLIST for transformWithState in PySpark
2c0fc708b88f is described below

commit 2c0fc708b88fef383b910c133aec97e1ec8a6a38
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Fri Apr 25 10:08:27 2025 +0900

[SPARK-51891][SS] Squeeze the protocol of ListState GET / PUT / APPENDLIST for transformWithState in PySpark

### What changes were proposed in this pull request?

This PR proposes to squeeze the protocol of ListState GET / PUT / APPENDLIST for transformWithState in PySpark, which helps significantly when dealing with small lists in ListState.

Here are the changes:

* ListState.get() no longer requires an additional request to learn that there is no further data to read.
  * We inline the data into the proto message, which makes it easy to determine whether the iterator has been fully consumed.
* ListState.put() / ListState.appendList() no longer require an additional request to send the data separately.
  * We inline the data into the proto message if the list we pass is small enough (currently "magically" set to 100 elements; this needs further investigation).
  * If the length of the list is over 100, we fall back to the "old" Arrow send (rather than the custom protocol). This is because a pickled Python Row contains the schema information as a string, which is larger than we anticipated, so at some point Arrow becomes more efficient.

NOTE: 100 is a sort of "magic number"; we will need to refine it with more benchmarking. (Minimal Python sketches of the GET iteration and the inline-vs-Arrow decision are appended after the diff below.)

### Why are the changes needed?

To further optimize ListState operations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50689 from HeartSaVioR/SPARK-51891.

Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/streaming/list_state_client.py  |  81 +++++++++--
 .../sql/streaming/proto/StateMessage_pb2.py        | 156 +++++++++++----------
 .../sql/streaming/proto/StateMessage_pb2.pyi       |  72 ++++++++++
 .../sql/streaming/stateful_processor_api_client.py |  12 ++
 .../helper/helper_pandas_transform_with_state.py   | 131 +++++++++++++++++
 .../pandas/test_pandas_transform_with_state.py     |  74 ++++++++++
 .../sql/execution/streaming/StateMessage.proto     |  11 ++
 .../TransformWithStateInPySparkStateServer.scala   |  69 ++++++++-
 ...ansformWithStateInPySparkStateServerSuite.scala |  62 +++++---
 9 files changed, 553 insertions(+), 115 deletions(-)

diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py index 66f2640c935e..08b672e86e08 100644 --- a/python/pyspark/sql/streaming/list_state_client.py +++ b/python/pyspark/sql/streaming/list_state_client.py @@ -37,7 +37,7 @@ class ListStateClient: self.schema = schema # A dictionary to store the mapping between list state name and a tuple of data batch # and the index of the last row that was read.
- self.data_batch_dict: Dict[str, Tuple[Any, int]] = {} + self.data_batch_dict: Dict[str, Tuple[Any, int, bool]] = {} def exists(self, state_name: str) -> bool: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage @@ -61,12 +61,12 @@ class ListStateClient: f"Error checking value state exists: " f"{response_message[1]}" ) - def get(self, state_name: str, iterator_id: str) -> Tuple: + def get(self, state_name: str, iterator_id: str) -> Tuple[Tuple, bool]: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage if iterator_id in self.data_batch_dict: # If the state is already in the dictionary, return the next row. - data_batch, index = self.data_batch_dict[iterator_id] + data_batch, index, require_next_fetch = self.data_batch_dict[iterator_id] else: # If the state is not in the dictionary, fetch the state from the server. get_call = stateMessage.ListStateGet(iteratorId=iterator_id) @@ -79,23 +79,35 @@ class ListStateClient: message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) - response_message = self._stateful_processor_api_client._receive_proto_message() + response_message = ( + self._stateful_processor_api_client._receive_proto_message_with_list_get() + ) status = response_message[0] if status == 0: - data_batch = self._stateful_processor_api_client._read_list_state() + data_batch = list( + map( + lambda x: self._stateful_processor_api_client._deserialize_from_bytes(x), + response_message[2], + ) + ) + require_next_fetch = response_message[3] index = 0 else: raise StopIteration() + is_last_row = False new_index = index + 1 if new_index < len(data_batch): # Update the index in the dictionary. - self.data_batch_dict[iterator_id] = (data_batch, new_index) + self.data_batch_dict[iterator_id] = (data_batch, new_index, require_next_fetch) else: # If the index is at the end of the data batch, remove the state from the dictionary. self.data_batch_dict.pop(iterator_id, None) + is_last_row = True + + is_last_row_from_iterator = is_last_row and not require_next_fetch row = data_batch[index] - return tuple(row) + return (tuple(row), is_last_row_from_iterator) def append_value(self, state_name: str, value: Tuple) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage @@ -118,7 +130,24 @@ class ListStateClient: def append_list(self, state_name: str, values: List[Tuple]) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - append_list_call = stateMessage.AppendList() + send_data_via_arrow = False + + # To workaround mypy type assignment check. + values_as_bytes: Any = [] + if len(values) == 100: + # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default + # value backed by various benchmarks. 
+ # Arrow codepath + send_data_via_arrow = True + else: + values_as_bytes = map( + lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x), + values, + ) + + append_list_call = stateMessage.AppendList( + value=values_as_bytes, fetchWithArrow=send_data_via_arrow + ) list_state_call = stateMessage.ListStateCall( stateName=state_name, appendList=append_list_call ) @@ -127,7 +156,9 @@ class ListStateClient: self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) - self._stateful_processor_api_client._send_list_state(self.schema, values) + if send_data_via_arrow: + self._stateful_processor_api_client._send_arrow_state(self.schema, values) + response_message = self._stateful_processor_api_client._receive_proto_message() status = response_message[0] if status != 0: @@ -137,14 +168,32 @@ class ListStateClient: def put(self, state_name: str, values: List[Tuple]) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - put_call = stateMessage.ListStatePut() + send_data_via_arrow = False + # To workaround mypy type assignment check. + values_as_bytes: Any = [] + if len(values) == 100: + # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default + # value backed by various benchmarks. + send_data_via_arrow = True + else: + values_as_bytes = map( + lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x), + values, + ) + + put_call = stateMessage.ListStatePut( + value=values_as_bytes, fetchWithArrow=send_data_via_arrow + ) + list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call) state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call) message = stateMessage.StateRequest(stateVariableRequest=state_variable_request) self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) - self._stateful_processor_api_client._send_list_state(self.schema, values) + if send_data_via_arrow: + self._stateful_processor_api_client._send_arrow_state(self.schema, values) + response_message = self._stateful_processor_api_client._receive_proto_message() status = response_message[0] if status != 0: @@ -174,9 +223,17 @@ class ListStateIterator: # Generate a unique identifier for the iterator to make sure iterators from the same # list state do not interfere with each other. 
self.iterator_id = str(uuid.uuid4()) + self.iterator_fully_consumed = False def __iter__(self) -> Iterator[Tuple]: return self def __next__(self) -> Tuple: - return self.list_state_client.get(self.state_name, self.iterator_id) + if self.iterator_fully_consumed: + raise StopIteration() + + row, is_last_row = self.list_state_client.get(self.state_name, self.iterator_id) + if is_last_row: + self.iterator_fully_consumed = True + + return row diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py index 20af541f307c..094f1dd51c58 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py @@ -40,7 +40,7 @@ _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\x84\x05\n\x0cStateRequest\x12\x18\n\x07version\x18\x01 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14st [...] + b'\n;org/apache/spark/sql/execution/streaming/StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\x84\x05\n\x0cStateRequest\x12\x18\n\x07version\x18\x01 \x01(\x05R\x07version\x12}\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00R\x15statefulProcessorCall\x12z\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00R\x14st [...] 
) _globals = globals() @@ -50,8 +50,8 @@ _builder.BuildTopDescriptorsAndMessages( ) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals["_HANDLESTATE"]._serialized_start = 6408 - _globals["_HANDLESTATE"]._serialized_end = 6518 + _globals["_HANDLESTATE"]._serialized_start = 6695 + _globals["_HANDLESTATE"]._serialized_end = 6805 _globals["_STATEREQUEST"]._serialized_start = 112 _globals["_STATEREQUEST"]._serialized_end = 756 _globals["_STATERESPONSE"]._serialized_start = 758 @@ -60,78 +60,80 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 985 _globals["_STATERESPONSEWITHSTRINGTYPEVAL"]._serialized_start = 987 _globals["_STATERESPONSEWITHSTRINGTYPEVAL"]._serialized_end = 1109 - _globals["_STATEFULPROCESSORCALL"]._serialized_start = 1112 - _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1784 - _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1787 - _globals["_STATEVARIABLEREQUEST"]._serialized_end = 2128 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 2131 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2390 - _globals["_TIMERREQUEST"]._serialized_start = 2393 - _globals["_TIMERREQUEST"]._serialized_end = 2650 - _globals["_TIMERVALUEREQUEST"]._serialized_start = 2653 - _globals["_TIMERVALUEREQUEST"]._serialized_end = 2899 - _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2901 - _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2967 - _globals["_GETPROCESSINGTIME"]._serialized_start = 2969 - _globals["_GETPROCESSINGTIME"]._serialized_end = 2988 - _globals["_GETWATERMARK"]._serialized_start = 2990 - _globals["_GETWATERMARK"]._serialized_end = 3004 - _globals["_UTILSREQUEST"]._serialized_start = 3007 - _globals["_UTILSREQUEST"]._serialized_end = 3146 - _globals["_PARSESTRINGSCHEMA"]._serialized_start = 3148 - _globals["_PARSESTRINGSCHEMA"]._serialized_end = 3191 - _globals["_STATECALLCOMMAND"]._serialized_start = 3194 - _globals["_STATECALLCOMMAND"]._serialized_end = 3393 - _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 3396 - _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3691 - _globals["_VALUESTATECALL"]._serialized_start = 3694 - _globals["_VALUESTATECALL"]._serialized_end = 4096 - _globals["_LISTSTATECALL"]._serialized_start = 4099 - _globals["_LISTSTATECALL"]._serialized_end = 4706 - _globals["_MAPSTATECALL"]._serialized_start = 4709 - _globals["_MAPSTATECALL"]._serialized_end = 5543 - _globals["_SETIMPLICITKEY"]._serialized_start = 5545 - _globals["_SETIMPLICITKEY"]._serialized_end = 5579 - _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5581 - _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5600 - _globals["_EXISTS"]._serialized_start = 5602 - _globals["_EXISTS"]._serialized_end = 5610 - _globals["_GET"]._serialized_start = 5612 - _globals["_GET"]._serialized_end = 5617 - _globals["_REGISTERTIMER"]._serialized_start = 5619 - _globals["_REGISTERTIMER"]._serialized_end = 5680 - _globals["_DELETETIMER"]._serialized_start = 5682 - _globals["_DELETETIMER"]._serialized_end = 5741 - _globals["_LISTTIMERS"]._serialized_start = 5743 - _globals["_LISTTIMERS"]._serialized_end = 5787 - _globals["_VALUESTATEUPDATE"]._serialized_start = 5789 - _globals["_VALUESTATEUPDATE"]._serialized_end = 5829 - _globals["_CLEAR"]._serialized_start = 5831 - _globals["_CLEAR"]._serialized_end = 5838 - _globals["_LISTSTATEGET"]._serialized_start = 5840 - _globals["_LISTSTATEGET"]._serialized_end = 5886 - _globals["_LISTSTATEPUT"]._serialized_start = 5888 - 
_globals["_LISTSTATEPUT"]._serialized_end = 5902 - _globals["_APPENDVALUE"]._serialized_start = 5904 - _globals["_APPENDVALUE"]._serialized_end = 5939 - _globals["_APPENDLIST"]._serialized_start = 5941 - _globals["_APPENDLIST"]._serialized_end = 5953 - _globals["_GETVALUE"]._serialized_start = 5955 - _globals["_GETVALUE"]._serialized_end = 5991 - _globals["_CONTAINSKEY"]._serialized_start = 5993 - _globals["_CONTAINSKEY"]._serialized_end = 6032 - _globals["_UPDATEVALUE"]._serialized_start = 6034 - _globals["_UPDATEVALUE"]._serialized_end = 6095 - _globals["_ITERATOR"]._serialized_start = 6097 - _globals["_ITERATOR"]._serialized_end = 6139 - _globals["_KEYS"]._serialized_start = 6141 - _globals["_KEYS"]._serialized_end = 6179 - _globals["_VALUES"]._serialized_start = 6181 - _globals["_VALUES"]._serialized_end = 6221 - _globals["_REMOVEKEY"]._serialized_start = 6223 - _globals["_REMOVEKEY"]._serialized_end = 6260 - _globals["_SETHANDLESTATE"]._serialized_start = 6262 - _globals["_SETHANDLESTATE"]._serialized_end = 6361 - _globals["_TTLCONFIG"]._serialized_start = 6363 - _globals["_TTLCONFIG"]._serialized_end = 6406 + _globals["_STATERESPONSEWITHLISTGET"]._serialized_start = 1112 + _globals["_STATERESPONSEWITHLISTGET"]._serialized_end = 1272 + _globals["_STATEFULPROCESSORCALL"]._serialized_start = 1275 + _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1947 + _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1950 + _globals["_STATEVARIABLEREQUEST"]._serialized_end = 2291 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 2294 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 2553 + _globals["_TIMERREQUEST"]._serialized_start = 2556 + _globals["_TIMERREQUEST"]._serialized_end = 2813 + _globals["_TIMERVALUEREQUEST"]._serialized_start = 2816 + _globals["_TIMERVALUEREQUEST"]._serialized_end = 3062 + _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 3064 + _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 3130 + _globals["_GETPROCESSINGTIME"]._serialized_start = 3132 + _globals["_GETPROCESSINGTIME"]._serialized_end = 3151 + _globals["_GETWATERMARK"]._serialized_start = 3153 + _globals["_GETWATERMARK"]._serialized_end = 3167 + _globals["_UTILSREQUEST"]._serialized_start = 3170 + _globals["_UTILSREQUEST"]._serialized_end = 3309 + _globals["_PARSESTRINGSCHEMA"]._serialized_start = 3311 + _globals["_PARSESTRINGSCHEMA"]._serialized_end = 3354 + _globals["_STATECALLCOMMAND"]._serialized_start = 3357 + _globals["_STATECALLCOMMAND"]._serialized_end = 3556 + _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 3559 + _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 3854 + _globals["_VALUESTATECALL"]._serialized_start = 3857 + _globals["_VALUESTATECALL"]._serialized_end = 4259 + _globals["_LISTSTATECALL"]._serialized_start = 4262 + _globals["_LISTSTATECALL"]._serialized_end = 4869 + _globals["_MAPSTATECALL"]._serialized_start = 4872 + _globals["_MAPSTATECALL"]._serialized_end = 5706 + _globals["_SETIMPLICITKEY"]._serialized_start = 5708 + _globals["_SETIMPLICITKEY"]._serialized_end = 5742 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 5744 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 5763 + _globals["_EXISTS"]._serialized_start = 5765 + _globals["_EXISTS"]._serialized_end = 5773 + _globals["_GET"]._serialized_start = 5775 + _globals["_GET"]._serialized_end = 5780 + _globals["_REGISTERTIMER"]._serialized_start = 5782 + _globals["_REGISTERTIMER"]._serialized_end = 5843 + _globals["_DELETETIMER"]._serialized_start = 5845 + 
_globals["_DELETETIMER"]._serialized_end = 5904 + _globals["_LISTTIMERS"]._serialized_start = 5906 + _globals["_LISTTIMERS"]._serialized_end = 5950 + _globals["_VALUESTATEUPDATE"]._serialized_start = 5952 + _globals["_VALUESTATEUPDATE"]._serialized_end = 5992 + _globals["_CLEAR"]._serialized_start = 5994 + _globals["_CLEAR"]._serialized_end = 6001 + _globals["_LISTSTATEGET"]._serialized_start = 6003 + _globals["_LISTSTATEGET"]._serialized_end = 6049 + _globals["_LISTSTATEPUT"]._serialized_start = 6051 + _globals["_LISTSTATEPUT"]._serialized_end = 6127 + _globals["_APPENDVALUE"]._serialized_start = 6129 + _globals["_APPENDVALUE"]._serialized_end = 6164 + _globals["_APPENDLIST"]._serialized_start = 6166 + _globals["_APPENDLIST"]._serialized_end = 6240 + _globals["_GETVALUE"]._serialized_start = 6242 + _globals["_GETVALUE"]._serialized_end = 6278 + _globals["_CONTAINSKEY"]._serialized_start = 6280 + _globals["_CONTAINSKEY"]._serialized_end = 6319 + _globals["_UPDATEVALUE"]._serialized_start = 6321 + _globals["_UPDATEVALUE"]._serialized_end = 6382 + _globals["_ITERATOR"]._serialized_start = 6384 + _globals["_ITERATOR"]._serialized_end = 6426 + _globals["_KEYS"]._serialized_start = 6428 + _globals["_KEYS"]._serialized_end = 6466 + _globals["_VALUES"]._serialized_start = 6468 + _globals["_VALUES"]._serialized_end = 6508 + _globals["_REMOVEKEY"]._serialized_start = 6510 + _globals["_REMOVEKEY"]._serialized_end = 6547 + _globals["_SETHANDLESTATE"]._serialized_start = 6549 + _globals["_SETHANDLESTATE"]._serialized_end = 6648 + _globals["_TTLCONFIG"]._serialized_start = 6650 + _globals["_TTLCONFIG"]._serialized_end = 6693 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi index ac4b03b82034..aa86826862bb 100644 --- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi +++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi @@ -34,7 +34,9 @@ See the License for the specific language governing permissions and limitations under the License. """ import builtins +import collections.abc import google.protobuf.descriptor +import google.protobuf.internal.containers import google.protobuf.internal.enum_type_wrapper import google.protobuf.message import sys @@ -229,6 +231,44 @@ class StateResponseWithStringTypeVal(google.protobuf.message.Message): global___StateResponseWithStringTypeVal = StateResponseWithStringTypeVal +class StateResponseWithListGet(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATUSCODE_FIELD_NUMBER: builtins.int + ERRORMESSAGE_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + REQUIRENEXTFETCH_FIELD_NUMBER: builtins.int + statusCode: builtins.int + errorMessage: builtins.str + @property + def value( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + requireNextFetch: builtins.bool + def __init__( + self, + *, + statusCode: builtins.int = ..., + errorMessage: builtins.str = ..., + value: collections.abc.Iterable[builtins.bytes] | None = ..., + requireNextFetch: builtins.bool = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "errorMessage", + b"errorMessage", + "requireNextFetch", + b"requireNextFetch", + "statusCode", + b"statusCode", + "value", + b"value", + ], + ) -> None: ... 
+ +global___StateResponseWithListGet = StateResponseWithListGet + class StatefulProcessorCall(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -1042,8 +1082,24 @@ global___ListStateGet = ListStateGet class ListStatePut(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALUE_FIELD_NUMBER: builtins.int + FETCHWITHARROW_FIELD_NUMBER: builtins.int + @property + def value( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + fetchWithArrow: builtins.bool def __init__( self, + *, + value: collections.abc.Iterable[builtins.bytes] | None = ..., + fetchWithArrow: builtins.bool = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "fetchWithArrow", b"fetchWithArrow", "value", b"value" + ], ) -> None: ... global___ListStatePut = ListStatePut @@ -1065,8 +1121,24 @@ global___AppendValue = AppendValue class AppendList(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + VALUE_FIELD_NUMBER: builtins.int + FETCHWITHARROW_FIELD_NUMBER: builtins.int + @property + def value( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bytes]: ... + fetchWithArrow: builtins.bool def __init__( self, + *, + value: collections.abc.Iterable[builtins.bytes] | None = ..., + fetchWithArrow: builtins.bool = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "fetchWithArrow", b"fetchWithArrow", "value", b"value" + ], ) -> None: ... global___AppendList = AppendList diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index e564d7186faa..18330c4096fa 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -425,6 +425,18 @@ class StatefulProcessorApiClient: message.ParseFromString(bytes) return message.statusCode, message.errorMessage, message.value + # The third return type is RepeatedScalarFieldContainer[bytes], which is protobuf's container + # type. We simplify it to Any here to avoid unnecessary complexity. 
+ def _receive_proto_message_with_list_get(self) -> Tuple[int, str, Any, bool]: + import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage + + length = read_int(self.sockfile) + bytes = self.sockfile.read(length) + message = stateMessage.StateResponseWithListGet() + message.ParseFromString(bytes) + + return message.statusCode, message.errorMessage, message.value, message.requireNextFetch + def _receive_str(self) -> str: return self.utf8_deserializer.loads(self.sockfile) diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py index 53f6d77567ae..cc9f29609a5b 100644 --- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py @@ -139,6 +139,14 @@ class ListStateProcessorFactory(StatefulProcessorFactory): return RowListStateProcessor() +class ListStateLargeListProcessorFactory(StatefulProcessorFactory): + def pandas(self): + return PandasListStateLargeListProcessor() + + def row(self): + return RowListStateLargeListProcessor() + + class ListStateLargeTTLProcessorFactory(StatefulProcessorFactory): def pandas(self): return PandasListStateLargeTTLProcessor() @@ -922,6 +930,129 @@ class RowListStateProcessor(StatefulProcessor): pass +class PandasListStateLargeListProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + list_state_schema = StructType([StructField("value", IntegerType(), True)]) + value_state_schema = StructType([StructField("size", IntegerType(), True)]) + self.list_state = handle.getListState("listState", list_state_schema) + self.list_size_state = handle.getValueState("listSizeState", value_state_schema) + + def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]: + elements_iter = self.list_state.get() + elements = list(elements_iter) + + # Use the magic number 100 to test with both inline proto case and Arrow case. + # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default + # value backed by various benchmarks. 
+ # Put 90 elements per batch: + # 1st batch: read 0 element, and write 90 elements, read back 90 elements + # (both use inline proto) + # 2nd batch: read 90 elements, and write 90 elements, read back 180 elements + # (read uses both inline proto and Arrow, write uses Arrow) + + if len(elements) == 0: + # should be the first batch + assert self.list_size_state.get() is None + new_elements = [(i,) for i in range(90)] + if key == ("0",): + self.list_state.put(new_elements) + else: + self.list_state.appendList(new_elements) + self.list_size_state.update((len(new_elements),)) + else: + # check the elements + list_size = self.list_size_state.get() + assert list_size is not None + list_size = list_size[0] + assert list_size == len( + elements + ), f"list_size ({list_size}) != len(elements) ({len(elements)})" + + expected_elements_in_state = [(i,) for i in range(list_size)] + assert elements == expected_elements_in_state + + if key == ("0",): + # Use the operation `put` + new_elements = [(i,) for i in range(list_size + 90)] + self.list_state.put(new_elements) + final_size = len(new_elements) + self.list_size_state.update((final_size,)) + else: + # Use the operation `appendList` + new_elements = [(i,) for i in range(list_size, list_size + 90)] + self.list_state.appendList(new_elements) + final_size = len(new_elements) + list_size + self.list_size_state.update((final_size,)) + + prev_elements = ",".join(map(lambda x: str(x[0]), elements)) + updated_elements = ",".join(map(lambda x: str(x[0]), self.list_state.get())) + + yield pd.DataFrame( + {"id": key, "prevElements": prev_elements, "updatedElements": updated_elements} + ) + + +class RowListStateLargeListProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + list_state_schema = StructType([StructField("value", IntegerType(), True)]) + value_state_schema = StructType([StructField("size", IntegerType(), True)]) + self.list_state = handle.getListState("listState", list_state_schema) + self.list_size_state = handle.getValueState("listSizeState", value_state_schema) + + def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]: + elements_iter = self.list_state.get() + + elements = list(elements_iter) + + # Use the magic number 100 to test with both inline proto case and Arrow case. + # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default + # value backed by various benchmarks. 
+ # Put 90 elements per batch: + # 1st batch: read 0 element, and write 90 elements, read back 90 elements + # (both use inline proto) + # 2nd batch: read 90 elements, and write 90 elements, read back 180 elements + # (read uses both inline proto and Arrow, write uses Arrow) + + if len(elements) == 0: + # should be the first batch + assert self.list_size_state.get() is None + new_elements = [(i,) for i in range(90)] + if key == ("0",): + self.list_state.put(new_elements) + else: + self.list_state.appendList(new_elements) + self.list_size_state.update((len(new_elements),)) + else: + # check the elements + list_size = self.list_size_state.get() + assert list_size is not None + list_size = list_size[0] + assert list_size == len( + elements + ), f"list_size ({list_size}) != len(elements) ({len(elements)})" + + expected_elements_in_state = [(i,) for i in range(list_size)] + assert elements == expected_elements_in_state + + if key == ("0",): + # Use the operation `put` + new_elements = [(i,) for i in range(list_size + 90)] + self.list_state.put(new_elements) + final_size = len(new_elements) + self.list_size_state.update((final_size,)) + else: + # Use the operation `appendList` + new_elements = [(i,) for i in range(list_size, list_size + 90)] + self.list_state.appendList(new_elements) + final_size = len(new_elements) + list_size + self.list_size_state.update((final_size,)) + + prev_elements = ",".join(map(lambda x: str(x[0]), elements)) + updated_elements = ",".join(map(lambda x: str(x[0]), self.list_state.get())) + + yield Row(id=key[0], prevElements=prev_elements, updatedElements=updated_elements) + + class PandasListStateLargeTTLProcessor(PandasListStateProcessor): def init(self, handle: StatefulProcessorHandle) -> None: state_schema = StructType([StructField("temperature", IntegerType(), True)]) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 7f2469cd6b93..e36ae3a86a28 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -57,6 +57,7 @@ from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import ( TTLStatefulProcessorFactory, InvalidSimpleStatefulProcessorFactory, ListStateProcessorFactory, + ListStateLargeListProcessorFactory, ListStateLargeTTLProcessorFactory, MapStateProcessorFactory, MapStateLargeTTLProcessorFactory, @@ -302,6 +303,79 @@ class TransformWithStateTestsMixin: ListStateProcessorFactory(), check_results, True, "processingTime" ) + def test_transform_with_state_list_state_large_list(self): + def check_results(batch_df, batch_id): + if batch_id == 0: + expected_prev_elements = "" + expected_updated_elements = ",".join(map(lambda x: str(x), range(90))) + else: + # batch_id == 1: + expected_prev_elements = ",".join(map(lambda x: str(x), range(90))) + expected_updated_elements = ",".join(map(lambda x: str(x), range(180))) + + assert set(batch_df.sort("id").collect()) == { + Row( + id="0", + prevElements=expected_prev_elements, + updatedElements=expected_updated_elements, + ), + Row( + id="1", + prevElements=expected_prev_elements, + updatedElements=expected_updated_elements, + ), + } + + input_path = tempfile.mkdtemp() + checkpoint_path = tempfile.mkdtemp() + + self._prepare_test_resource1(input_path) + time.sleep(2) + self._prepare_test_resource2(input_path) + + df = self._build_test_df(input_path) + + for q in self.spark.streams.active: + q.stop() + 
self.assertTrue(df.isStreaming) + + output_schema = StructType( + [ + StructField("id", StringType(), True), + StructField("prevElements", StringType(), True), + StructField("updatedElements", StringType(), True), + ] + ) + + stateful_processor = self.get_processor(ListStateLargeListProcessorFactory()) + if self.use_pandas(): + tws_df = df.groupBy("id").transformWithStateInPandas( + statefulProcessor=stateful_processor, + outputStructType=output_schema, + outputMode="Update", + timeMode="none", + ) + else: + tws_df = df.groupBy("id").transformWithState( + statefulProcessor=stateful_processor, + outputStructType=output_schema, + outputMode="Update", + timeMode="none", + ) + + q = ( + tws_df.writeStream.queryName("this_query") + .option("checkpointLocation", checkpoint_path) + .foreachBatch(check_results) + .outputMode("update") + .start() + ) + self.assertEqual(q.name, "this_query") + self.assertTrue(q.isActive) + q.processAllAvailable() + q.awaitTermination(10) + self.assertTrue(q.exception() is None) + # test list state with ttl has the same behavior as list state when state doesn't expire. def test_transform_with_state_list_state_large_ttl(self): def check_results(batch_df, batch_id): diff --git a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto index 1374bd100a2f..ce83c285410b 100644 --- a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto +++ b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto @@ -48,6 +48,13 @@ message StateResponseWithStringTypeVal { string value = 3; } +message StateResponseWithListGet { + int32 statusCode = 1; + string errorMessage = 2; + repeated bytes value = 3; + bool requireNextFetch = 4; +} + message StatefulProcessorCall { oneof method { SetHandleState setHandleState = 1; @@ -197,6 +204,8 @@ message ListStateGet { } message ListStatePut { + repeated bytes value = 1; + bool fetchWithArrow = 2; } message AppendValue { @@ -204,6 +213,8 @@ message AppendValue { } message AppendList { + repeated bytes value = 1; + bool fetchWithArrow = 2; } message GetValue { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala index 541ccf14b06d..a2c4d130ef31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala @@ -22,6 +22,7 @@ import java.nio.channels.{Channels, ServerSocketChannel} import java.time.Duration import scala.collection.mutable +import scala.jdk.CollectionConverters._ import com.google.protobuf.ByteString import org.apache.arrow.vector.VectorSchemaRoot @@ -38,6 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleImplBase, StatefulProcessorHandleState, StateVariableType} import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall, StateRequest, StateResponse, 
StateResponseWithLongTypeVal, StateResponseWithStringTypeVal, StateVariableRequest, TimerRequest, TimerStateCallCommand, TimerValueRequest, UtilsRequest, ValueStateCall} +import org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponseWithListGet import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, ValueState} import org.apache.spark.sql.types.{BinaryType, LongType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils @@ -481,7 +483,17 @@ class TransformWithStateInPySparkStateServer( sendResponse(2, s"state $stateName doesn't exist") } case ListStateCall.MethodCase.LISTSTATEPUT => - val rows = deserializer.readListElements(inputStream, listStateInfo) + val rows = if (message.getListStatePut.getFetchWithArrow) { + deserializer.readArrowBatches(inputStream) + } else { + val elements = message.getListStatePut.getValueList.asScala + elements.map { e => + PythonSQLUtils.toJVMRow( + e.toByteArray, + listStateInfo.schema, + listStateInfo.deserializer) + } + } listStateInfo.listState.put(rows.toArray) sendResponse(0) case ListStateCall.MethodCase.LISTSTATEGET => @@ -494,8 +506,7 @@ class TransformWithStateInPySparkStateServer( if (!iteratorOption.get.hasNext) { sendResponse(2, s"List state $stateName doesn't contain any value.") } else { - sendResponse(0) - sendIteratorForListState(iteratorOption.get) + sendResponseWithListGet(0, iter = iteratorOption.get) } case ListStateCall.MethodCase.APPENDVALUE => val byteArray = message.getAppendValue.getValue.toByteArray @@ -504,7 +515,17 @@ class TransformWithStateInPySparkStateServer( listStateInfo.listState.appendValue(newRow) sendResponse(0) case ListStateCall.MethodCase.APPENDLIST => - val rows = deserializer.readListElements(inputStream, listStateInfo) + val rows = if (message.getAppendList.getFetchWithArrow) { + deserializer.readArrowBatches(inputStream) + } else { + val elements = message.getAppendList.getValueList.asScala + elements.map { e => + PythonSQLUtils.toJVMRow( + e.toByteArray, + listStateInfo.schema, + listStateInfo.deserializer) + } + } listStateInfo.listState.appendList(rows.toArray) sendResponse(0) case ListStateCall.MethodCase.CLEAR => @@ -771,6 +792,46 @@ class TransformWithStateInPySparkStateServer( outputStream.write(responseMessageBytes) } + def sendResponseWithListGet( + status: Int, + errorMessage: String = null, + iter: Iterator[Row] = null): Unit = { + val responseMessageBuilder = StateResponseWithListGet.newBuilder() + .setStatusCode(status) + if (status != 0 && errorMessage != null) { + responseMessageBuilder.setErrorMessage(errorMessage) + } + + if (status == 0) { + // Only write a single batch in each GET request. Stops writing row if rowCount reaches + // the arrowTransformWithStateInPySparkMaxRecordsPerBatch limit. This is to handle a case + // when there are multiple state variables, user tries to access a different state variable + // while the current state variable is not exhausted yet. 
+ var rowCount = 0 + while (iter.hasNext && rowCount < arrowTransformWithStateInPySparkMaxRecordsPerBatch) { + val data = iter.next() + + // Serialize the value row as a byte array + val valueBytes = PythonSQLUtils.toPyRow(data) + + responseMessageBuilder.addValue(ByteString.copyFrom(valueBytes)) + + rowCount += 1 + } + + assert(rowCount > 0, s"rowCount should be greater than 0 when status code is 0, " + + s"iter.hasNext ${iter.hasNext}") + + responseMessageBuilder.setRequireNextFetch(iter.hasNext) + } + + val responseMessage = responseMessageBuilder.build() + val responseMessageBytes = responseMessage.toByteArray + val byteLength = responseMessageBytes.length + outputStream.writeInt(byteLength) + outputStream.write(responseMessageBytes) + } + def sendIteratorAsArrowBatches[T]( iter: Iterator[T], outputSchema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala index 318fd6ce1c8f..fd750edf96ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala @@ -262,10 +262,10 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build() stateServer.handleListStateRequest(message) verify(listState, times(0)).get() - // 1 for row, 1 for end of the data, 1 for proto response - verify(outputStream, times(3)).writeInt(any) - // 1 for sending an actual row, 1 for sending proto message - verify(outputStream, times(2)).write(any[Array[Byte]]) + // 1 for proto response + verify(outputStream).writeInt(any) + // 1 for sending proto message + verify(outputStream).write(any[Array[Byte]]) } test("list state get - iterator in map with multiple batches") { @@ -282,20 +282,20 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef // First call should send 2 records. stateServer.handleListStateRequest(message) verify(listState, times(0)).get() - // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response - verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any) - // maxRecordsPerBatch times for rows, 1 for sending proto message - verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]]) + // 1 for proto response + verify(outputStream).writeInt(any) + // 1 for proto message + verify(outputStream).write(any[Array[Byte]]) // Second call should send the remaining 2 records. stateServer.handleListStateRequest(message) verify(listState, times(0)).get() // Since Mockito's verify counts the total number of calls, the expected number of writeInt // and write should be accumulated from the prior count; the number of calls are the same // with prior one. 
- // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response - verify(outputStream, times(maxRecordsPerBatch * 2 + 4)).writeInt(any) - // maxRecordsPerBatch times for rows, 1 for sending proto message - verify(outputStream, times(maxRecordsPerBatch * 2 + 2)).write(any[Array[Byte]]) + // 1 for proto response + verify(outputStream, times(2)).writeInt(any) + // 1 for sending proto message + verify(outputStream, times(2)).write(any[Array[Byte]]) } test("list state get - iterator not in map") { @@ -314,17 +314,26 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still // having 1 row left in the iterator. - // maxRecordsPerBatch (2) for rows, 1 for end of the data, 1 for proto response - verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any) - // 2 for rows, 1 for proto message - verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]]) + // 1 for proto response + verify(outputStream, times(1)).writeInt(any) + // 1 for proto message + verify(outputStream, times(1)).write(any[Array[Byte]]) + } + + test("list state put - inlined data") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStatePut(ListStatePut.newBuilder().setFetchWithArrow(false).build()).build() + stateServer.handleListStateRequest(message) + // Verify that the data is not read from Arrow stream. It is inlined. + verify(transformWithStateInPySparkDeserializer, times(0)).readArrowBatches(any) + verify(listState).put(any) } - test("list state put") { + test("list state put - data via Arrow batch") { val message = ListStateCall.newBuilder().setStateName(stateName) - .setListStatePut(ListStatePut.newBuilder().build()).build() + .setListStatePut(ListStatePut.newBuilder().setFetchWithArrow(true).build()).build() stateServer.handleListStateRequest(message) - verify(transformWithStateInPySparkDeserializer).readListElements(any, any) + verify(transformWithStateInPySparkDeserializer).readArrowBatches(any) verify(listState).put(any) } @@ -336,11 +345,20 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef verify(listState).appendValue(any[Row]) } - test("list state append list") { + test("list state append list - inlined data") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setAppendList(AppendList.newBuilder().setFetchWithArrow(false).build()).build() + stateServer.handleListStateRequest(message) + // Verify that the data is not read from Arrow stream. It is inlined. + verify(transformWithStateInPySparkDeserializer, times(0)).readArrowBatches(any) + verify(listState).appendList(any) + } + + test("list state append list - data via Arrow batch") { val message = ListStateCall.newBuilder().setStateName(stateName) - .setAppendList(AppendList.newBuilder().build()).build() + .setAppendList(AppendList.newBuilder().setFetchWithArrow(true).build()).build() stateServer.handleListStateRequest(message) - verify(transformWithStateInPySparkDeserializer).readListElements(any, any) + verify(transformWithStateInPySparkDeserializer).readArrowBatches(any) verify(listState).appendList(any) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
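To make the GET-side change above concrete: each `StateResponseWithListGet` carries one batch of serialized rows plus a `requireNextFetch` flag, so the client no longer needs a separate round trip just to learn that the list is exhausted. Below is a minimal, self-contained Python sketch of that paging loop. `ListGetResponse`, `fetch_batch`, and `iterate_list_state` are names invented for this sketch, not PySpark APIs (the real logic lives in `ListStateClient.get()` and `ListStateIterator`), and plain `pickle` stands in for the actual row serializer.

```python
# Illustrative sketch only -- these names are not PySpark APIs.
import pickle
from dataclasses import dataclass
from typing import Callable, Iterator, List, Tuple


@dataclass
class ListGetResponse:
    # Mirrors the shape of StateResponseWithListGet: one batch of serialized
    # rows plus a flag telling the client whether another fetch is required.
    values: List[bytes]
    require_next_fetch: bool


def iterate_list_state(fetch_batch: Callable[[], ListGetResponse]) -> Iterator[Tuple]:
    """Yield rows batch by batch; stop without issuing an extra request once
    the server reports that nothing is left to fetch."""
    while True:
        response = fetch_batch()
        for raw in response.values:
            yield pickle.loads(raw)
        if not response.require_next_fetch:
            return


if __name__ == "__main__":
    # Two server "pages": the first says more data remains, the second does not.
    pages = [
        ListGetResponse([pickle.dumps((i,)) for i in range(3)], require_next_fetch=True),
        ListGetResponse([pickle.dumps((i,)) for i in range(3, 6)], require_next_fetch=False),
    ]
    assert list(iterate_list_state(lambda: pages.pop(0))) == [(i,) for i in range(6)]
```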
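The PUT / APPENDLIST side chooses between inlining pickled rows into the `ListStatePut` / `AppendList` message and setting `fetchWithArrow` so the rows follow as an Arrow batch, as before. Here is a minimal sketch of that decision, assuming the "over 100 elements" rule from the description above; the exact comparison in `ListStateClient.put()` / `append_list()` may differ, and `INLINE_THRESHOLD`, `InlinePutRequest`, `ArrowPutRequest`, and `build_put_request` are stand-in names, not PySpark APIs.

```python
# Illustrative sketch only -- these names are not PySpark APIs.
import pickle
from dataclasses import dataclass, field
from typing import List, Tuple, Union

INLINE_THRESHOLD = 100  # the "magic number" from the description; see SPARK-51907


@dataclass
class InlinePutRequest:
    # Corresponds to ListStatePut / AppendList with `value` populated and
    # fetchWithArrow = false: no separate data transfer is needed.
    values: List[bytes] = field(default_factory=list)
    fetch_with_arrow: bool = False


@dataclass
class ArrowPutRequest:
    # Corresponds to ListStatePut / AppendList with fetchWithArrow = true:
    # the rows follow as an Arrow batch, as they did before this change.
    fetch_with_arrow: bool = True


def build_put_request(rows: List[Tuple]) -> Union[InlinePutRequest, ArrowPutRequest]:
    """Inline small lists into the proto message; fall back to Arrow otherwise,
    since pickled Rows carry schema strings and grow faster than Arrow data."""
    if len(rows) > INLINE_THRESHOLD:
        return ArrowPutRequest()
    return InlinePutRequest(values=[pickle.dumps(r) for r in rows])


if __name__ == "__main__":
    assert isinstance(build_put_request([(i,) for i in range(5)]), InlinePutRequest)
    assert isinstance(build_put_request([(i,) for i in range(500)]), ArrowPutRequest)
```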
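On the JVM side, `sendResponseWithListGet` writes at most `arrowTransformWithStateInPySparkMaxRecordsPerBatch` rows per response and sets `requireNextFetch` from `iter.hasNext`, so an iterator that is only partially consumed (for example, when the user switches to another state variable) can be resumed by a later GET. A simplified Python stand-in for that per-request batching follows; `take_batch` and `max_records_per_batch` are names invented for this sketch, and the real implementation is the Scala method shown in the diff.

```python
# Illustrative sketch only -- a Python stand-in for the Scala server-side batching.
import itertools
import pickle
from dataclasses import dataclass
from typing import Iterator, List, Tuple


@dataclass
class ListGetResponse:
    values: List[bytes]
    require_next_fetch: bool


def take_batch(rows: Iterator[Tuple], max_records_per_batch: int):
    """Serialize at most one batch of rows and report whether more remain, so a
    later request can resume the same iterator."""
    batch = [pickle.dumps(r) for r in itertools.islice(rows, max_records_per_batch)]
    # Peek one element to emulate Scala's iter.hasNext without losing it.
    sentinel = object()
    head = next(rows, sentinel)
    if head is sentinel:
        return ListGetResponse(batch, require_next_fetch=False), rows
    return ListGetResponse(batch, require_next_fetch=True), itertools.chain([head], rows)


if __name__ == "__main__":
    it: Iterator[Tuple] = iter([(i,) for i in range(5)])
    resp1, it = take_batch(it, max_records_per_batch=2)
    resp2, it = take_batch(it, max_records_per_batch=2)
    resp3, it = take_batch(it, max_records_per_batch=2)
    assert [len(r.values) for r in (resp1, resp2, resp3)] == [2, 2, 1]
    assert [r.require_next_fetch for r in (resp1, resp2, resp3)] == [True, True, False]
```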