This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3f663bf58313 [SPARK-53870][PYTHON][SS] Fix partial read bug for large proto messages in TransformWithStateInPySparkStateServer
3f663bf58313 is described below

commit 3f663bf583135295dcaba9e03fe9a722eb55665b
Author: Jason Teoh <[email protected]>
AuthorDate: Mon Oct 13 07:03:22 2025 +0900

    [SPARK-53870][PYTHON][SS] Fix partial read bug for large proto messages in TransformWithStateInPySparkStateServer
    
    ### What changes were proposed in this pull request?
    
    Fix the TransformWithState StateServer's `parseProtoMessage` method to fully read the desired message using the correct [readFully DataInputStream API](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/io/DataInput.html#readFully(byte%5B%5D)) rather than `read` (InputStream/FilterInputStream), which only reads the data currently available and may not return the full message. [`readFully` (DataInputStream)](https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/io/DataI [...]
    
    In addition to the linked API above, this StackOverflow post also illustrates the difference between the two APIs: https://stackoverflow.com/a/25900095
    
    ### Why are the changes needed?
    
    For large state values used in the TransformWithState API, `inputStream.read` is not guaranteed to read `messageLen` bytes of data, as per the InputStream API contract. For large values, `read` can return prematurely, leaving `messageBytes` only partially filled and yielding an incorrect and likely unparseable proto message.
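    
    As a rough, hypothetical sketch (not code from this patch; the helper name `readExactly` is made up), this is the loop a caller would need around `read` to guarantee `messageLen` bytes are consumed, and it is effectively what `readFully` already provides:
    ```
    import java.io.{DataInputStream, EOFException}

    // Hypothetical helper (not part of this patch): read() may fill only part of
    // the buffer, so it has to be retried until messageLen bytes have arrived.
    // DataInputStream.readFully performs this loop (including end-of-stream
    // handling) for the caller.
    def readExactly(in: DataInputStream, messageLen: Int): Array[Byte] = {
      val buf = new Array[Byte](messageLen)
      var offset = 0
      while (offset < messageLen) {
        val n = in.read(buf, offset, messageLen - offset)
        if (n == -1) {
          throw new EOFException(s"Stream ended after $offset of $messageLen bytes")
        }
        offset += n
      }
      buf
    }
    ```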
    
    This is not a common scenario, as testing also indicated that the actual proto messages had to be somewhat large to consistently trigger this error. The test case I added uses 512KB strings in the state value updates.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a new test case using 512KB strings:
    
    - Value state update
    - List state update with 3 (different) values (note: list state provides a multi-value update API, so this message is even larger than the other two)
    - Map state update with single key/value
    
    ```
    build/sbt -Phive -Phive-thriftserver -DskipTests package
    python/run-tests --testnames 'pyspark.sql.tests.pandas.test_pandas_transform_with_state TransformWithStateInPandasTests'
    python/run-tests --testnames 'pyspark.sql.tests.pandas.test_pandas_transform_with_state TransformWithStateInPySparkTests'
    ```
    
    The configured data size (512KB) triggers an incomplete read while still completing in a reasonable time (within 30s on my laptop). I separately tested a larger input size of 4MB, which took 30 minutes and which I considered too expensive to include in the test.
    Below are sample test results from using `read` only (i.e., no fix) and adding a check on message length vs. read bytes ([test code is included in this commit](https://github.com/apache/spark/pull/52539/commits/b68cfd7c814f7050515e785d6813b624d68c3a59) but reverted later for the PR). The check is no longer required after the `readFully` fix, as that is handled within the provided API.
    ```
        TransformWithStateInPandasTests
            pyspark.errors.exceptions.base.PySparkRuntimeError: Error updating map state value: TESTING: Failed to read message bytes: expected 524369 bytes, but only read 261312 bytes
    
        TransformWithStateInPySparkTests
            pyspark.errors.exceptions.base.PySparkRuntimeError: Error updating value state: TESTING: Failed to read message bytes: expected 524336 bytes, but only read 392012 bytes
    
    ```
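    
    The reverted diagnostic boiled down to a length check around the old `read` call, roughly like the hypothetical sketch below (the exception type and exact wording are illustrative, not the reverted code):
    ```
    // Illustrative only: before the fix, read() reports how many bytes it actually
    // consumed, so the shortfall can be detected and surfaced to the Python side.
    val bytesRead = inputStream.read(messageBytes)
    if (bytesRead != messageLen) {
      throw new IllegalStateException(
        s"TESTING: Failed to read message bytes: expected $messageLen bytes, " +
          s"but only read $bytesRead bytes")
    }
    ```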
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude Code (claude-sonnet-4-5-20250929)
    
    Closes #52539 from jiateoh/tws_readFully_fix.
    
    Lead-authored-by: Jason Teoh <[email protected]>
    Co-authored-by: Jason Teoh <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../helper/helper_pandas_transform_with_state.py   | 92 ++++++++++++++++++++++
 .../pandas/test_pandas_transform_with_state.py     | 45 +++++++++++
 .../TransformWithStateInPySparkStateServer.scala   |  2 +-
 3 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
index 2946b894b7f8..09ef3a447f9c 100644
--- a/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py
@@ -255,6 +255,14 @@ class CompositeOutputProcessorFactory(StatefulProcessorFactory):
         return RowCompositeOutputProcessor()
 
 
+class LargeValueStatefulProcessorFactory(StatefulProcessorFactory):
+    def pandas(self):
+        return PandasLargeValueStatefulProcessor()
+
+    def row(self):
+        return RowLargeValueStatefulProcessor()
+
+
 # StatefulProcessor implementations
 
 
@@ -2039,3 +2047,87 @@ class RowCompositeOutputProcessor(StatefulProcessor):
 
     def close(self) -> None:
         pass
+
+
+class PandasLargeValueStatefulProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle):
+        # Test all three state types with large values
+        value_state_schema = StructType([StructField("value", StringType(), True)])
+        self.value_state = handle.getValueState("valueState", value_state_schema)
+
+        list_state_schema = StructType([StructField("value", StringType(), True)])
+        self.list_state = handle.getListState("listState", list_state_schema)
+
+        self.map_state = handle.getMapState("mapState", "key string", "value string")
+
+    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+        # Create a large string (512 KB)
+        target_size_bytes = 512 * 1024
+        large_string = "a" * target_size_bytes
+
+        # Test ValueState with large string
+        self.value_state.update((large_string,))
+        value_retrieved = self.value_state.get()[0]
+
+        # Test ListState with large strings
+        self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)])
+        list_retrieved = list(self.list_state.get())
+        list_elements = ",".join([elem[0] for elem in list_retrieved])
+
+        # Test MapState with large strings
+        map_key = ("large_string_key",)
+        self.map_state.updateValue(map_key, (large_string,))
+        map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}"
+
+        yield pd.DataFrame(
+            {
+                "id": key,
+                "valueStateResult": [value_retrieved],
+                "listStateResult": [list_elements],
+                "mapStateResult": [map_retrieved],
+            }
+        )
+
+    def close(self) -> None:
+        pass
+
+
+class RowLargeValueStatefulProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle):
+        # Test all three state types with large values
+        value_state_schema = StructType([StructField("value", StringType(), True)])
+        self.value_state = handle.getValueState("valueState", value_state_schema)
+
+        list_state_schema = StructType([StructField("value", StringType(), True)])
+        self.list_state = handle.getListState("listState", list_state_schema)
+
+        self.map_state = handle.getMapState("mapState", "key string", "value string")
+
+    def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
+        # Create a large string (512 KB)
+        target_size_bytes = 512 * 1024
+        large_string = "a" * target_size_bytes
+
+        # Test ValueState with large string
+        self.value_state.update((large_string,))
+        value_retrieved = self.value_state.get()[0]
+
+        # Test ListState with large strings
+        self.list_state.put([(large_string,), (large_string + "b",), (large_string + "c",)])
+        list_retrieved = list(self.list_state.get())
+        list_elements = ",".join([elem[0] for elem in list_retrieved])
+
+        # Test MapState with large strings
+        map_key = ("large_string_key",)
+        self.map_state.updateValue(map_key, (large_string,))
+        map_retrieved = f"{map_key[0]}:{self.map_state.getValue(map_key)[0]}"
+
+        yield Row(
+            id=key[0],
+            valueStateResult=value_retrieved,
+            listStateResult=list_elements,
+            mapStateResult=map_retrieved,
+        )
+
+    def close(self) -> None:
+        pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 8359c08a43a6..6932ef28fc59 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -65,6 +65,7 @@ from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import (
     ListStateLargeTTLProcessorFactory,
     MapStateProcessorFactory,
     MapStateLargeTTLProcessorFactory,
+    LargeValueStatefulProcessorFactory,
     BasicProcessorFactory,
     BasicProcessorNotNullableFactory,
     AddFieldsProcessorFactory,
@@ -2264,6 +2265,50 @@ class TransformWithStateTestsMixin:
                 ),
             )
 
+    # test all state types (value, list, map) with large values (512 KB)
+    def test_transform_with_state_large_values(self):
+        def check_results(batch_df, batch_id):
+            batch_df.collect()
+            # Create expected large string (512 KB)
+            target_size_bytes = 512 * 1024
+            large_string = "a" * target_size_bytes
+            expected_list_elements = ",".join(
+                [large_string, large_string + "b", large_string + "c"]
+            )
+            expected_map_result = f"large_string_key:{large_string}"
+
+            assert set(batch_df.sort("id").collect()) == {
+                Row(
+                    id="0",
+                    valueStateResult=large_string,
+                    listStateResult=expected_list_elements,
+                    mapStateResult=expected_map_result,
+                ),
+                Row(
+                    id="1",
+                    valueStateResult=large_string,
+                    listStateResult=expected_list_elements,
+                    mapStateResult=expected_map_result,
+                ),
+            }
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("valueStateResult", StringType(), True),
+                StructField("listStateResult", StringType(), True),
+                StructField("mapStateResult", StringType(), True),
+            ]
+        )
+
+        self._test_transform_with_state_basic(
+            LargeValueStatefulProcessorFactory(),
+            check_results,
+            True,
+            "None",
+            output_schema=output_schema,
+        )
+
 
 @unittest.skipIf(
     not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
index 937b6232ee4a..4fee6a6e71d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
@@ -208,7 +208,7 @@ class TransformWithStateInPySparkStateServer(
   private def parseProtoMessage(): StateRequest = {
     val messageLen = inputStream.readInt()
     val messageBytes = new Array[Byte](messageLen)
-    inputStream.read(messageBytes)
+    inputStream.readFully(messageBytes)
     StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
