This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 746db3d07860 [SPARK-53743][SS] Remove the usage of fetchWithArrow in ListState.put/appendList
746db3d07860 is described below
commit 746db3d07860a919da1c3c49ef0e4e950f80cd2c
Author: Jungtaek Lim <[email protected]>
AuthorDate: Tue Sep 30 10:04:03 2025 +0900
[SPARK-53743][SS] Remove the usage of fetchWithArrow in ListState.put/appendList
### What changes were proposed in this pull request?
This PR proposes to remove the usage of fetchWithArrow in ListState.put/appendList.
(We keep fetchWithArrow and its proto definition: removing them would not reduce complexity
noticeably, and removing a field from the proto could have unexpected side effects on
compatibility.)
### Why are the changes needed?
We have observed a case where the Arrow path for sending the list fails while the normal path
works. The failing case is a `None` value for an IntegerType() column in a list state element:
the column is declared nullable=True, so the value should be allowed, but an error is raised
during the conversion.
```
  File "/databricks/spark/python/pyspark/sql/streaming/stateful_processor.py", line 147, in put
    self._listStateClient.put(self._stateName, newState)
  File "/databricks/spark/python/pyspark/sql/streaming/list_state_client.py", line 195, in put
    self._stateful_processor_api_client._send_arrow_state(self.schema, values)
  File "/spark/python/pyspark/sql/streaming/stateful_processor_api_client.py", line 604, in _send_arrow_state
    pandas_df = convert_pandas_using_numpy_type(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/spark/python/pyspark/sql/pandas/types.py", line 1599, in convert_pandas_using_numpy_type
    df[field.name] = df[field.name].astype(np_type)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/generic.py", line 6643, in astype
    new_data = self._mgr.astype(dtype=dtype, copy=copy, errors=errors)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/internals/managers.py", line 430, in astype
    return self.apply(
           ^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/internals/managers.py", line 363, in apply
    applied = getattr(b, f)(**kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/internals/blocks.py", line 758, in astype
    new_values = astype_array_safe(values, dtype, copy=copy, errors=errors)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/dtypes/astype.py", line 237, in astype_array_safe
    new_values = astype_array(values, dtype, copy=copy)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/dtypes/astype.py", line 182, in astype_array
    values = _astype_nansafe(values, dtype, copy=copy)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/dtypes/astype.py", line 133, in _astype_nansafe
    return arr.astype(dtype, copy=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
```
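For reference, the failure can be reproduced outside Spark with a small pandas/NumPy sketch (the
DataFrame below is a hypothetical stand-in for what `convert_pandas_using_numpy_type` receives
for a nullable IntegerType() column):
```python
import numpy as np
import pandas as pd

# Object-typed column standing in for list state elements; one element is None,
# which should be legal because the column is nullable.
df = pd.DataFrame({"value": [1, 2, None]}, dtype=object)

# Casting to a NumPy integer dtype cannot represent None, so this raises the same
# "TypeError: int() argument must be ... not 'NoneType'" as in the trace above.
df["value"] = df["value"].astype(np.int32)
```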
Since it is unclear how useful the Arrow-based list sending is, it is better not to try to fix
the issue in the Arrow code path at this point and simply stop using it.
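For illustration, a processor along these lines hits the failure before this change (a minimal
sketch; the class name and schema are hypothetical, and exactly 100 elements is what previously
flipped `put`/`appendList` onto the Arrow path):
```python
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StructField, StructType


class NullableListProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Nullable column, so None elements should be accepted by the list state.
        schema = StructType([StructField("value", IntegerType(), True)])
        self.list_state = handle.getListState("listState", schema)

    def handleInputRows(self, key, rows, timer_values):
        # Exactly 100 elements used to route put() through the Arrow send path,
        # where the None values failed during the pandas conversion shown above.
        self.list_state.put([(None,) for _ in range(100)])
        yield from rows

    def close(self) -> None:
        pass
```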
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Updated the existing test to cover the observed case.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52479 from HeartSaVioR/SPARK-53743.
Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
python/pyspark/sql/streaming/list_state_client.py | 48 ++++-------------
.../helper/helper_pandas_transform_with_state.py | 60 ++++++++++++----------
.../pandas/test_pandas_transform_with_state.py | 6 +--
.../TransformWithStateInPySparkStateServer.scala | 8 +++
4 files changed, 55 insertions(+), 67 deletions(-)
diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py
index 08b672e86e08..89de69ed0625 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -130,24 +130,12 @@ class ListStateClient:
     def append_list(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        send_data_via_arrow = False
-
-        # To workaround mypy type assignment check.
-        values_as_bytes: Any = []
-        if len(values) == 100:
-            # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
-            # value backed by various benchmarks.
-            # Arrow codepath
-            send_data_via_arrow = True
-        else:
-            values_as_bytes = map(
-                lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
-                values,
-            )
-
-        append_list_call = stateMessage.AppendList(
-            value=values_as_bytes, fetchWithArrow=send_data_via_arrow
+        values_as_bytes = map(
+            lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
+            values,
         )
+
+        append_list_call = stateMessage.AppendList(value=values_as_bytes, fetchWithArrow=False)
         list_state_call = stateMessage.ListStateCall(
             stateName=state_name, appendList=append_list_call
         )
@@ -156,9 +144,6 @@ class ListStateClient:
 
         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        if send_data_via_arrow:
-            self._stateful_processor_api_client._send_arrow_state(self.schema, values)
-
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
@@ -168,32 +153,19 @@ class ListStateClient:
     def put(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        send_data_via_arrow = False
-        # To workaround mypy type assignment check.
-        values_as_bytes: Any = []
-        if len(values) == 100:
-            # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
-            # value backed by various benchmarks.
-            send_data_via_arrow = True
-        else:
-            values_as_bytes = map(
-                lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
-                values,
-            )
-
-        put_call = stateMessage.ListStatePut(
-            value=values_as_bytes, fetchWithArrow=send_data_via_arrow
+        values_as_bytes = map(
+            lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
+            values,
         )
+        put_call = stateMessage.ListStatePut(value=values_as_bytes, fetchWithArrow=False)
+
         list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
         state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
         message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
 
         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        if send_data_via_arrow:
-            self._stateful_processor_api_client._send_arrow_state(self.schema, values)
-
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
index d258f693ccb8..a35bae88bedb 100644
--- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
@@ -942,7 +942,12 @@ class RowListStateProcessor(StatefulProcessor):
 
 class PandasListStateLargeListProcessor(StatefulProcessor):
     def init(self, handle: StatefulProcessorHandle) -> None:
-        list_state_schema = StructType([StructField("value", IntegerType(), True)])
+        list_state_schema = StructType(
+            [
+                StructField("value", IntegerType(), True),
+                StructField("valueNull", IntegerType(), True),
+            ]
+        )
         value_state_schema = StructType([StructField("size", IntegerType(), True)])
         self.list_state = handle.getListState("listState", list_state_schema)
         self.list_size_state = handle.getValueState("listSizeState", value_state_schema)
@@ -952,18 +957,15 @@ class PandasListStateLargeListProcessor(StatefulProcessor):
         elements = list(elements_iter)
 
         # Use the magic number 100 to test with both inline proto case and Arrow case.
-        # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
-        # value backed by various benchmarks.
-        # Put 90 elements per batch:
-        # 1st batch: read 0 element, and write 90 elements, read back 90 elements
-        # (both use inline proto)
-        # 2nd batch: read 90 elements, and write 90 elements, read back 180 elements
-        # (read uses both inline proto and Arrow, write uses Arrow)
+        # Now the magic number is not actually used, but this is to make this test be a regression
+        # test of SPARK-53743.
+        # Explicitly put 100 elements of list which triggered Arrow based list serialization before
+        # SPARK-53743.
 
         if len(elements) == 0:
             # should be the first batch
             assert self.list_size_state.get() is None
-            new_elements = [(i,) for i in range(90)]
+            new_elements = [(i, None) for i in range(100)]
             if key == ("0",):
                 self.list_state.put(new_elements)
             else:
@@ -978,18 +980,20 @@ class PandasListStateLargeListProcessor(StatefulProcessor):
                 elements
             ), f"list_size ({list_size}) != len(elements) ({len(elements)})"
 
-            expected_elements_in_state = [(i,) for i in range(list_size)]
-            assert elements == expected_elements_in_state
+            expected_elements_in_state = [(i, None) for i in range(list_size)]
+            assert (
+                elements == expected_elements_in_state
+            ), f"expected {expected_elements_in_state} but got {elements}"
 
             if key == ("0",):
                 # Use the operation `put`
-                new_elements = [(i,) for i in range(list_size + 90)]
+                new_elements = [(i, None) for i in range(list_size + 90)]
                 self.list_state.put(new_elements)
                 final_size = len(new_elements)
                 self.list_size_state.update((final_size,))
             else:
                 # Use the operation `appendList`
-                new_elements = [(i,) for i in range(list_size, list_size + 90)]
+                new_elements = [(i, None) for i in range(list_size, list_size + 90)]
                 self.list_state.appendList(new_elements)
                 final_size = len(new_elements) + list_size
                 self.list_size_state.update((final_size,))
@@ -1004,7 +1008,12 @@ class PandasListStateLargeListProcessor(StatefulProcessor):
 
 class RowListStateLargeListProcessor(StatefulProcessor):
     def init(self, handle: StatefulProcessorHandle) -> None:
-        list_state_schema = StructType([StructField("value", IntegerType(), True)])
+        list_state_schema = StructType(
+            [
+                StructField("value", IntegerType(), True),
+                StructField("valueNull", IntegerType(), True),
+            ]
+        )
         value_state_schema = StructType([StructField("size", IntegerType(), True)])
         self.list_state = handle.getListState("listState", list_state_schema)
         self.list_size_state = handle.getValueState("listSizeState", value_state_schema)
@@ -1015,18 +1024,15 @@ class RowListStateLargeListProcessor(StatefulProcessor):
         elements = list(elements_iter)
 
         # Use the magic number 100 to test with both inline proto case and Arrow case.
-        # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
-        # value backed by various benchmarks.
-        # Put 90 elements per batch:
-        # 1st batch: read 0 element, and write 90 elements, read back 90 elements
-        # (both use inline proto)
-        # 2nd batch: read 90 elements, and write 90 elements, read back 180 elements
-        # (read uses both inline proto and Arrow, write uses Arrow)
+        # Now the magic number is not actually used, but this is to make this test be a regression
+        # test of SPARK-53743.
+        # Explicitly put 100 elements of list which triggered Arrow based list serialization before
+        # SPARK-53743.
 
         if len(elements) == 0:
             # should be the first batch
             assert self.list_size_state.get() is None
-            new_elements = [(i,) for i in range(90)]
+            new_elements = [(i, None) for i in range(100)]
             if key == ("0",):
                 self.list_state.put(new_elements)
             else:
@@ -1041,18 +1047,20 @@ class RowListStateLargeListProcessor(StatefulProcessor):
                 elements
             ), f"list_size ({list_size}) != len(elements) ({len(elements)})"
 
-            expected_elements_in_state = [(i,) for i in range(list_size)]
-            assert elements == expected_elements_in_state
+            expected_elements_in_state = [(i, None) for i in range(list_size)]
+            assert (
+                elements == expected_elements_in_state
+            ), f"expected {expected_elements_in_state} but got {elements}"
 
             if key == ("0",):
                 # Use the operation `put`
-                new_elements = [(i,) for i in range(list_size + 90)]
+                new_elements = [(i, None) for i in range(list_size + 90)]
                 self.list_state.put(new_elements)
                 final_size = len(new_elements)
                 self.list_size_state.update((final_size,))
             else:
                 # Use the operation `appendList`
-                new_elements = [(i,) for i in range(list_size, list_size + 90)]
+                new_elements = [(i, None) for i in range(list_size, list_size + 90)]
                 self.list_state.appendList(new_elements)
                 final_size = len(new_elements) + list_size
                 self.list_size_state.update((final_size,))
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 6d79a8c26753..af44093c512d 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -312,11 +312,11 @@ class TransformWithStateTestsMixin:
             batch_df.collect()
             if batch_id == 0:
                 expected_prev_elements = ""
-                expected_updated_elements = ",".join(map(lambda x: str(x), range(90)))
+                expected_updated_elements = ",".join(map(lambda x: str(x), range(100)))
             else:
                 # batch_id == 1:
-                expected_prev_elements = ",".join(map(lambda x: str(x), range(90)))
-                expected_updated_elements = ",".join(map(lambda x: str(x), range(180)))
+                expected_prev_elements = ",".join(map(lambda x: str(x), range(100)))
+                expected_updated_elements = ",".join(map(lambda x: str(x), range(190)))
 
             assert set(batch_df.sort("id").collect()) == {
                 Row(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
index f5fec2f85dff..937b6232ee4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
@@ -490,6 +490,10 @@ class TransformWithStateInPySparkStateServer(
           sendResponse(2, s"state $stateName doesn't exist")
         }
       case ListStateCall.MethodCase.LISTSTATEPUT =>
+        // TODO: Check whether we can safely remove fetchWithArrow without breaking backward
+        //  compatibility (Spark Connect)
+        // TODO: Also check whether fetchWithArrow has a clear benefit to be retained (in terms
+        //  of performance)
         val rows = if (message.getListStatePut.getFetchWithArrow) {
           deserializer.readArrowBatches(inputStream)
         } else {
@@ -522,6 +526,10 @@ class TransformWithStateInPySparkStateServer(
         listStateInfo.listState.appendValue(newRow)
         sendResponse(0)
       case ListStateCall.MethodCase.APPENDLIST =>
+        // TODO: Check whether we can safely remove fetchWithArrow without breaking backward
+        //  compatibility (Spark Connect)
+        // TODO: Also check whether fetchWithArrow has a clear benefit to be retained (in terms
+        //  of performance)
         val rows = if (message.getAppendList.getFetchWithArrow) {
           deserializer.readArrowBatches(inputStream)
         } else {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]