This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3f663bf58313 [SPARK-53870][PYTHON][SS] Fix partial read bug for large proto messages in TransformWithStateInPySparkStateServer
3f663bf58313 is described below
commit 3f663bf583135295dcaba9e03fe9a722eb55665b
Author: Jason Teoh <[email protected]>
AuthorDate: Mon Oct 13 07:03:22 2025 +0900
[SPARK-53870][PYTHON][SS] Fix partial read bug for large proto messages in TransformWithStateInPySparkStateServer
### What changes were proposed in this pull request?
Fix the TransformWithState StateServer's `parseProtoMessage` method to fully read the desired message using the correct [readFully DataInputStream API](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/io/DataInput.html#readFully(byte%5B%5D)) rather than `read` (InputStream/FilterInputStream), which only reads the bytes currently available and may not return the full message. [`readFully` (DataInputStream)](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/io/DataI
[...]
In addition to the linked API above, this StackOverflow post also
illustrates the difference between the two APIs:
https://stackoverflow.com/a/25900095
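For illustration, the loop below is a rough sketch (not the server code; `readFullyByHand` is a hypothetical helper) of what a correct manual read has to do, which is exactly what `DataInputStream.readFully` already provides:
```scala
import java.io.{DataInputStream, EOFException, InputStream}

// Hypothetical helper: a single read() on a socket stream may legally return
// fewer than buf.length bytes, so a correct manual read must loop until the
// buffer is full or the stream ends.
def readFullyByHand(in: InputStream, buf: Array[Byte]): Unit = {
  var off = 0
  while (off < buf.length) {
    val n = in.read(buf, off, buf.length - off)
    if (n < 0) {
      throw new EOFException(s"stream ended after $off of ${buf.length} bytes")
    }
    off += n
  }
}

// DataInputStream.readFully performs this loop internally:
//   val messageBytes = new Array[Byte](messageLen)
//   dataInputStream.readFully(messageBytes)  // blocks until all bytes arrive or throws EOFException
```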
### Why are the changes needed?
For large state values used in the TransformWithState API, `inputStream.read` is not guaranteed to read `messageLen` bytes of data, per the InputStream API contract. For large values, `read` can return prematurely, leaving `messageBytes` only partially filled and yielding an incorrect, likely unparseable proto message.
This is not a common scenario: testing indicated that the proto messages have to be fairly large to consistently trigger this error. The test case I added uses 512 KB strings in the state value updates.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a new test case using 512KB strings:
- Value state update
- List state update with 3 (different) values (note: list state provides a
multi-value update API, so this message is even larger than the other two)
- Map state update with single key/value
```
build/sbt -Phive -Phive-thriftserver -DskipTests package
python/run-tests --testnames 'pyspark.sql.tests.pandas.test_pandas_transform_with_state TransformWithStateInPandasTests'
python/run-tests --testnames 'pyspark.sql.tests.pandas.test_pandas_transform_with_state TransformWithStateInPySparkTests'
```
The configured data size (512 KB) triggers an incomplete read while still completing in a reasonable time (within 30s on my laptop). I separately tested a larger input size of 4 MB, which took about 30 minutes, which I considered too expensive to include in the test.
Below are sample test results from using `read` only (i.e., no fix) and adding a check on message length vs. bytes read ([test code is included in this commit](https://github.com/apache/spark/pull/52539/commits/b68cfd7c814f7050515e785d6813b624d68c3a59) but reverted later for the PR). The check is no longer required after the `readFully` fix, as that handling is built into the API.
```
TransformWithStateInPandasTests
    pyspark.errors.exceptions.base.PySparkRuntimeError: Error updating map state value: TESTING: Failed to read message bytes: expected 524369 bytes, but only read 261312 bytes

TransformWithStateInPySparkTests
    pyspark.errors.exceptions.base.PySparkRuntimeError: Error updating value state: TESTING: Failed to read message bytes: expected 524336 bytes, but only read 392012 bytes
```
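For reference, the reverted diagnostic was conceptually along these lines (a hedged sketch, not the exact test code from the linked commit; the exception type here is illustrative):
```scala
// Sketch: read with the unfixed API and compare the requested message length
// against the number of bytes read() actually returned.
val messageLen = inputStream.readInt()
val messageBytes = new Array[Byte](messageLen)
val bytesRead = inputStream.read(messageBytes)
if (bytesRead != messageLen) {
  throw new IllegalStateException(
    s"TESTING: Failed to read message bytes: expected $messageLen bytes, but only read $bytesRead bytes")
}
```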
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude Code (claude-sonnet-4-5-20250929)
Closes #52539 from jiateoh/tws_readFully_fix.
Lead-authored-by: Jason Teoh <[email protected]>
Co-authored-by: Jason Teoh <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../helper/helper_pandas_transform_with_state.py | 92 ++++++++++++++++++++++
.../pandas/test_pandas_transform_with_state.py | 45 +++++++++++
.../TransformWithStateInPySparkStateServer.scala | 2 +-
3 files changed, 138 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
index 2946b894b7f8..09ef3a447f9c 100644
--- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
@@ -255,6 +255,14 @@ class CompositeOutputProcessorFactory(StatefulProcessorFactory):
         return RowCompositeOutputProcessor()
+class LargeValueStatefulProcessorFactory(StatefulProcessorFactory):
+    def pandas(self):
+        return PandasLargeValueStatefulProcessor()
+
+    def row(self):
+        return RowLargeValueStatefulProcessor()
+
+
 # StatefulProcessor implementations
@@ -2039,3 +2047,87 @@ class RowCompositeOutputProcessor(StatefulProcessor):
     def close(self) -> None:
         pass
+
+
+class PandasLargeValueStatefulProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle):
+        # Test all three state types with large values
+        value_state_schema = StructType([StructField("value", StringType(), True)])
+        self.value_state = handle.getValueState("valueState", value_state_schema)
+
+        list_state_schema = StructType([StructField("value", StringType(), True)])
+        self.list_state = handle.getListState("listState", list_state_schema)
+
+        self.map_state = handle.getMapState("mapState", "key string", "value string")
+
+    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+        # Create a large string (512 KB)
+        target_size_bytes = 512 * 1024
+        large_string = "a" * target_size_bytes
+
+        # Test ValueState with large string
+        self.value_state.update((large_string,))
+        value_retrieved = self.value_state.get()[0]
+
+        # Test ListState with large strings
+        self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)])
+        list_retrieved = list(self.list_state.get())
+        list_elements = ",".join([elem[0] for elem in list_retrieved])
+
+        # Test MapState with large strings
+        map_key = ("large_string_key",)
+        self.map_state.updateValue(map_key, (large_string,))
+        map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}"
+
+        yield pd.DataFrame(
+            {
+                "id": key,
+                "valueStateResult": [value_retrieved],
+                "listStateResult": [list_elements],
+                "mapStateResult": [map_retrieved],
+            }
+        )
+
+    def close(self) -> None:
+        pass
+
+
+class RowLargeValueStatefulProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle):
+        # Test all three state types with large values
+        value_state_schema = StructType([StructField("value", StringType(), True)])
+        self.value_state = handle.getValueState("valueState", value_state_schema)
+
+        list_state_schema = StructType([StructField("value", StringType(), True)])
+        self.list_state = handle.getListState("listState", list_state_schema)
+
+        self.map_state = handle.getMapState("mapState", "key string", "value string")
+
+    def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+        # Create a large string (512 KB)
+        target_size_bytes = 512 * 1024
+        large_string = "a" * target_size_bytes
+
+        # Test ValueState with large string
+        self.value_state.update((large_string,))
+        value_retrieved = self.value_state.get()[0]
+
+        # Test ListState with large strings
+        self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)])
+        list_retrieved = list(self.list_state.get())
+        list_elements = ",".join([elem[0] for elem in list_retrieved])
+
+        # Test MapState with large strings
+        map_key = ("large_string_key",)
+        self.map_state.updateValue(map_key, (large_string,))
+        map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}"
+
+        yield Row(
+            id=key[0],
+            valueStateResult=value_retrieved,
+            listStateResult=list_elements,
+            mapStateResult=map_retrieved,
+        )
+
+    def close(self) -> None:
+        pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 8359c08a43a6..6932ef28fc59 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -65,6 +65,7 @@ from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import (
     ListStateLargeTTLProcessorFactory,
     MapStateProcessorFactory,
     MapStateLargeTTLProcessorFactory,
+    LargeValueStatefulProcessorFactory,
     BasicProcessorFactory,
     BasicProcessorNotNullableFactory,
     AddFieldsProcessorFactory,
@@ -2264,6 +2265,50 @@ class TransformWithStateTestsMixin:
),
)
+    # test all state types (value, list, map) with large values (512 KB)
+    def test_transform_with_state_large_values(self):
+        def check_results(batch_df, batch_id):
+            batch_df.collect()
+            # Create expected large string (512 KB)
+            target_size_bytes = 512 * 1024
+            large_string = "a" * target_size_bytes
+            expected_list_elements = ",".join(
+                [large_string, large_string + "b", large_string + "c"]
+            )
+            expected_map_result = f"large_string_key:{large_string}"
+
+            assert set(batch_df.sort("id").collect()) == {
+                Row(
+                    id="0",
+                    valueStateResult=large_string,
+                    listStateResult=expected_list_elements,
+                    mapStateResult=expected_map_result,
+                ),
+                Row(
+                    id="1",
+                    valueStateResult=large_string,
+                    listStateResult=expected_list_elements,
+                    mapStateResult=expected_map_result,
+                ),
+            }
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("valueStateResult", StringType(), True),
+                StructField("listStateResult", StringType(), True),
+                StructField("mapStateResult", StringType(), True),
+            ]
+        )
+
+        self._test_transform_with_state_basic(
+            LargeValueStatefulProcessorFactory(),
+            check_results,
+            True,
+            "None",
+            output_schema=output_schema,
+        )
+
@unittest.skipIf(
not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
index 937b6232ee4a..4fee6a6e71d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
@@ -208,7 +208,7 @@ class TransformWithStateInPySparkStateServer(
   private def parseProtoMessage(): StateRequest = {
     val messageLen = inputStream.readInt()
     val messageBytes = new Array[Byte](messageLen)
-    inputStream.read(messageBytes)
+    inputStream.readFully(messageBytes)
     StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
   }
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]