This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 092223365928 [SPARK-52126][SS][PYTHON] Revert rename for TWS utility classes for forward compatibility in Spark Connect
092223365928 is described below

commit 092223365928d03014db6a65cfe9c28000e3c4d1
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Wed May 14 17:03:27 2025 +0900

    [SPARK-52126][SS][PYTHON] Revert rename for TWS utility classes for forward compatibility in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to revert the renaming of TWS utility classes since it broke forward compatibility in Spark Connect.
    
    We looked for a way to keep the refactoring and have the new version understand the old class names, but Python's Enumeration unfortunately does not support inheritance, and supporting both Enumeration classes would have required a broader code change.
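
    For illustration only (a minimal sketch, not part of this change; the enum name and members mirror the ones in the diff below), Python rejects subclassing an Enum that already defines members, so the old and new enum classes cannot simply be related via inheritance:

    ```python
    from enum import Enum

    class TransformWithStateInPandasFuncMode(Enum):
        PROCESS_DATA = 1
        PROCESS_TIMER = 2
        COMPLETE = 3
        PRE_INIT = 4

    try:
        # Subclassing an Enum that already has members is rejected at
        # class-creation time, which rules out bridging the two names via inheritance.
        class TransformWithStateInPySparkFuncMode(TransformWithStateInPandasFuncMode):
            pass
    except TypeError as exc:
        print(exc)  # e.g. "... cannot be extended" (wording varies by Python version)
    ```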
    
    ### Why are the changes needed?
    
    Without this fix, a Spark 4.0 client cannot work with a Spark 4.1 server in Spark Connect.
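
    For context, here is a hypothetical, self-contained sketch of how such a break can surface, assuming the client ships the utility class to the server-side Python worker as a pickled reference by module and class name (the module name `tws_util_demo` and helper `make_module` are made up for illustration and are not part of Spark):

    ```python
    import pickle
    import sys
    import types

    def make_module(class_name: str) -> types.ModuleType:
        """Build a throwaway module exposing one empty class under class_name."""
        mod = types.ModuleType("tws_util_demo")
        setattr(mod, class_name, type(class_name, (), {"__module__": "tws_util_demo"}))
        return mod

    # "Old client": the module still exposes the Pandas-named class; pickle
    # records only a reference (module name + class name), not the class body.
    sys.modules["tws_util_demo"] = make_module("TransformWithStateInPandasUdfUtils")
    payload = pickle.dumps(sys.modules["tws_util_demo"].TransformWithStateInPandasUdfUtils())

    # "New server": same module path, but the class was renamed, so the old name is gone.
    sys.modules["tws_util_demo"] = make_module("TransformWithStateInPySparkUdfUtils")

    try:
        pickle.loads(payload)
    except AttributeError as exc:
        print(exc)  # Can't get attribute 'TransformWithStateInPandasUdfUtils' on <module 'tws_util_demo'>
    ```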
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing UTs. For the compatibility aspect, we verified it manually via an internal system test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50883 from HeartSaVioR/SPARK-52126.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/connect/group.py                |  8 +++---
 python/pyspark/sql/pandas/group_ops.py             |  4 +--
 python/pyspark/sql/pandas/serializers.py           | 32 +++++++++++-----------
 .../sql/streaming/stateful_processor_util.py       | 28 +++++++++++--------
 .../transform_with_state_driver_worker.py          |  6 ++--
 python/pyspark/worker.py                           | 12 ++++----
 6 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index ef0384cf8252..8e1d46d66abd 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -374,10 +374,10 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkUdfUtils,
+            TransformWithStateInPandasUdfUtils,
         )
 
-        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
         if initialState is None:
             udf_obj = UserDefinedFunction(
                 udf_util.transformWithStateUDF,
@@ -426,10 +426,10 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkUdfUtils,
+            TransformWithStateInPandasUdfUtils,
         )
 
-        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
         if initialState is None:
             udf_obj = UserDefinedFunction(
                 udf_util.transformWithStateUDF,
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 5fe711f742ce..08e795e99a1d 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -635,7 +635,7 @@ class PandasGroupedOpsMixin:
         from pyspark.sql import GroupedData
         from pyspark.sql.functions import pandas_udf
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkUdfUtils,
+            TransformWithStateInPandasUdfUtils,
         )
 
         assert isinstance(self, GroupedData)
@@ -645,7 +645,7 @@ class PandasGroupedOpsMixin:
             outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
 
         df = self._df
-        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
 
         # explicitly set the type to Any since it could match to various types (literals)
         functionType: Any = None
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 47a8f5df37e6..b3fa77f8ba20 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1258,7 +1258,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         """
         import pyarrow as pa
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkFuncMode,
+            TransformWithStateInPandasFuncMode,
         )
 
         def generate_data_batches(batches):
@@ -1284,11 +1284,11 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
 
-        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
-        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
 
     def dump_stream(self, iterator, stream):
         """
@@ -1326,7 +1326,7 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
     def load_stream(self, stream):
         import pyarrow as pa
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkFuncMode,
+            TransformWithStateInPandasFuncMode,
         )
 
         def generate_data_batches(batches):
@@ -1388,11 +1388,11 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
 
-        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
-        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
 
 
 class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
@@ -1420,7 +1420,7 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
         this function works in overall.
         """
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkFuncMode,
+            TransformWithStateInPandasFuncMode,
         )
         import itertools
 
@@ -1451,11 +1451,11 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
         for k, g in groupby(data_batches, key=lambda x: x[0]):
             chained = itertools.chain(g)
             chained_values = map(lambda x: x[1], chained)
-            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, chained_values)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, chained_values)
 
-        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
-        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
 
     def dump_stream(self, iterator, stream):
         """
@@ -1503,7 +1503,7 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
         import itertools
         import pyarrow as pa
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkFuncMode,
+            TransformWithStateInPandasFuncMode,
         )
 
         def generate_data_batches(batches):
@@ -1592,8 +1592,8 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
 
             ret_tuple = (chained_input_values_without_none, chained_init_state_values_without_none)
 
-            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, ret_tuple)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
 
-        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
-        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
diff --git a/python/pyspark/sql/streaming/stateful_processor_util.py b/python/pyspark/sql/streaming/stateful_processor_util.py
index c0ff176eb9c9..d13d32403ffe 100644
--- a/python/pyspark/sql/streaming/stateful_processor_util.py
+++ b/python/pyspark/sql/streaming/stateful_processor_util.py
@@ -38,10 +38,12 @@ if TYPE_CHECKING:
 # contains public APIs.
 
 
-class TransformWithStateInPySparkFuncMode(Enum):
+class TransformWithStateInPandasFuncMode(Enum):
     """
     Internal mode for python worker UDF mode for transformWithState in PySpark; external mode are
     in `StatefulProcessorHandleState` for public use purposes.
+
+    NOTE: The class has `Pandas` in its name for compatibility purposes in Spark Connect.
     """
 
     PROCESS_DATA = 1
@@ -50,10 +52,12 @@ class TransformWithStateInPySparkFuncMode(Enum):
     PRE_INIT = 4
 
 
-class TransformWithStateInPySparkUdfUtils:
+class TransformWithStateInPandasUdfUtils:
     """
     Internal Utility class used for python worker UDF for transformWithState in PySpark. This class
     is shared for both classic and spark connect mode.
+
+    NOTE: The class has `Pandas` in its name for compatibility purposes in Spark Connect.
     """
 
     def __init__(self, stateful_processor: StatefulProcessor, time_mode: str):
@@ -63,11 +67,11 @@ class TransformWithStateInPySparkUdfUtils:
     def transformWithStateUDF(
         self,
         stateful_processor_api_client: StatefulProcessorApiClient,
-        mode: TransformWithStateInPySparkFuncMode,
+        mode: TransformWithStateInPandasFuncMode,
         key: Any,
         input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
     ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
-        if mode == TransformWithStateInPySparkFuncMode.PRE_INIT:
+        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
             return self._handle_pre_init(stateful_processor_api_client)
 
         handle = StatefulProcessorHandle(stateful_processor_api_client)
@@ -76,13 +80,13 @@ class TransformWithStateInPySparkUdfUtils:
             self._stateful_processor.init(handle)
             stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
 
-        if mode == TransformWithStateInPySparkFuncMode.PROCESS_TIMER:
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
             stateful_processor_api_client.set_handle_state(
                 StatefulProcessorHandleState.DATA_PROCESSED
             )
             result = self._handle_expired_timers(stateful_processor_api_client)
             return result
-        elif mode == TransformWithStateInPySparkFuncMode.COMPLETE:
+        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
             stateful_processor_api_client.set_handle_state(
                 StatefulProcessorHandleState.TIMER_PROCESSED
             )
@@ -91,14 +91,14 @@ class TransformWithStateInPySparkUdfUtils:
             stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
             return iter([])
         else:
-            # mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA
+            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
             result = self._handle_data_rows(stateful_processor_api_client, key, input_rows)
             return result
 
     def transformWithStateWithInitStateUDF(
         self,
         stateful_processor_api_client: StatefulProcessorApiClient,
-        mode: TransformWithStateInPySparkFuncMode,
+        mode: TransformWithStateInPandasFuncMode,
         key: Any,
         input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
         initial_states: Optional[Union[Iterator["PandasDataFrameLike"], Iterator[Row]]] = None,
@@ -115,7 +119,7 @@ class TransformWithStateInPySparkUdfUtils:
         - `initialStates` is None, while `inputRows` is not empty. This is not first batch.
          `initialStates` is initialized to the positional value as None.
         """
-        if mode == TransformWithStateInPySparkFuncMode.PRE_INIT:
+        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
             return self._handle_pre_init(stateful_processor_api_client)
 
         handle = StatefulProcessorHandle(stateful_processor_api_client)
@@ -124,19 +128,19 @@ class TransformWithStateInPySparkUdfUtils:
             self._stateful_processor.init(handle)
             stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
 
-        if mode == TransformWithStateInPySparkFuncMode.PROCESS_TIMER:
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
             stateful_processor_api_client.set_handle_state(
                 StatefulProcessorHandleState.DATA_PROCESSED
             )
             result = self._handle_expired_timers(stateful_processor_api_client)
             return result
-        elif mode == TransformWithStateInPySparkFuncMode.COMPLETE:
+        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
             stateful_processor_api_client.remove_implicit_key()
             self._stateful_processor.close()
             stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
             return iter([])
         else:
-            # mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA
+            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
             batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
                 self._time_mode
             )
diff --git a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
index 8d9bed7e6187..3fe7f68a99e5 100644
--- a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
+++ b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
@@ -31,7 +31,7 @@ from pyspark.util import handle_worker_exception
 from typing import IO
 from pyspark.worker_util import check_python_version
 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPySparkFuncMode
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
 from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
@@ -51,7 +51,7 @@ def main(infile: IO, outfile: IO) -> None:
 
     def process(
         processor: StatefulProcessorApiClient,
-        mode: TransformWithStateInPySparkFuncMode,
+        mode: TransformWithStateInPandasFuncMode,
         key: Any,
         input: Iterator["PandasDataFrameLike"],
     ) -> None:
@@ -83,7 +83,7 @@ def main(infile: IO, outfile: IO) -> None:
         stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
         process(
             stateful_processor_api_client,
-            TransformWithStateInPySparkFuncMode.PRE_INIT,
+            TransformWithStateInPandasFuncMode.PRE_INIT,
             None,
             iter([]),
         )
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 15b4641ce838..18da7bcc7704 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -34,7 +34,7 @@ from pyspark.accumulators import (
     _deserialize_accumulator,
 )
 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPySparkFuncMode
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.resource import ResourceInformation
 from pyspark.util import PythonEvalType, local_connect_and_auth
@@ -589,7 +589,7 @@ def wrap_grouped_transform_with_state_udf(f, return_type, runner_conf):
 
 def wrap_grouped_transform_with_state_init_state_udf(f, return_type, runner_conf):
     def wrapped(stateful_processor_api_client, mode, key, values):
-        if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
             values_gen = values[0]
             init_states_gen = values[1]
         else:
@@ -2001,7 +2001,7 @@ def read_udfs(pickleSer, infile, eval_type):
         def mapper(a):
             mode = a[0]
 
-            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
                 key = a[1]
 
                 def values_gen():
@@ -2038,7 +2038,7 @@ def read_udfs(pickleSer, infile, eval_type):
         def mapper(a):
             mode = a[0]
 
-            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
                 key = a[1]
 
                 def values_gen():
@@ -2070,7 +2070,7 @@ def read_udfs(pickleSer, infile, eval_type):
         def mapper(a):
             mode = a[0]
 
-            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
                 key = a[1]
                 values = a[2]
 
@@ -2103,7 +2103,7 @@ def read_udfs(pickleSer, infile, eval_type):
         def mapper(a):
             mode = a[0]
 
-            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
                 key = a[1]
                 values = a[2]
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
