This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 092223365928 [SPARK-52126][SS][PYTHON] Revert rename for TWS utility classes for forward compatibility in Spark Connect
092223365928 is described below

commit 092223365928d03014db6a65cfe9c28000e3c4d1
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Wed May 14 17:03:27 2025 +0900

    [SPARK-52126][SS][PYTHON] Revert rename for TWS utility classes for forward compatibility in Spark Connect

    ### What changes were proposed in this pull request?

    This PR proposes to revert the renaming of TWS utility classes, since the rename broke forward compatibility in Spark Connect. We looked for a way to keep the refactor and have the new version understand the old class names, but Python's Enum unfortunately does not support inheritance, and supporting both Enum classes would require a broader code change.

    ### Why are the changes needed?

    Without this fix, a Spark 4.0 client cannot work with a Spark 4.1 server in Spark Connect.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Existing UTs. Regarding the compatibility test, we verified the change manually on an internal system test.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #50883 from HeartSaVioR/SPARK-52126.

    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/connect/group.py                |  8 +++---
 python/pyspark/sql/pandas/group_ops.py             |  4 +--
 python/pyspark/sql/pandas/serializers.py           | 32 +++++++++++-----------
 .../sql/streaming/stateful_processor_util.py       | 28 +++++++++++--------
 .../transform_with_state_driver_worker.py          |  6 ++--
 python/pyspark/worker.py                           | 12 ++++----
 6 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index ef0384cf8252..8e1d46d66abd 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -374,10 +374,10 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkUdfUtils,
+            TransformWithStateInPandasUdfUtils,
         )
 
-        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
         if initialState is None:
             udf_obj = UserDefinedFunction(
                 udf_util.transformWithStateUDF,
@@ -426,10 +426,10 @@ class GroupedData:
         from pyspark.sql.connect.udf import UserDefinedFunction
         from pyspark.sql.connect.dataframe import DataFrame
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkUdfUtils,
+            TransformWithStateInPandasUdfUtils,
         )
 
-        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
         if initialState is None:
             udf_obj = UserDefinedFunction(
                 udf_util.transformWithStateUDF,
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 5fe711f742ce..08e795e99a1d 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -635,7 +635,7 @@ class PandasGroupedOpsMixin:
         from pyspark.sql import GroupedData
         from pyspark.sql.functions import pandas_udf
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkUdfUtils,
+            TransformWithStateInPandasUdfUtils,
         )
 
         assert isinstance(self, GroupedData)
@@ -645,7 +645,7 @@ class PandasGroupedOpsMixin:
             outputStructType = cast(StructType, self._df._session._parse_ddl(outputStructType))
         df = self._df
 
-        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
 
         # explicitly set the type to Any since it could match to various types (literals)
         functionType: Any = None
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 47a8f5df37e6..b3fa77f8ba20 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1258,7 +1258,7 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         """
         import pyarrow as pa
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkFuncMode,
+            TransformWithStateInPandasFuncMode,
         )
 
         def generate_data_batches(batches):
@@ -1284,11 +1284,11 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
 
-        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
-        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
 
     def dump_stream(self, iterator, stream):
         """
@@ -1326,7 +1326,7 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
     def load_stream(self, stream):
         import pyarrow as pa
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkFuncMode,
+            TransformWithStateInPandasFuncMode,
         )
 
         def generate_data_batches(batches):
@@ -1388,11 +1388,11 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
 
-        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
-        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
 
 
 class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
@@ -1420,7 +1420,7 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
         this function works in overall.
         """
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkFuncMode,
+            TransformWithStateInPandasFuncMode,
         )
         import itertools
 
@@ -1451,11 +1451,11 @@ class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
         for k, g in groupby(data_batches, key=lambda x: x[0]):
             chained = itertools.chain(g)
             chained_values = map(lambda x: x[1], chained)
-            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, chained_values)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, chained_values)
 
-        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
-        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
 
     def dump_stream(self, iterator, stream):
         """
@@ -1503,7 +1503,7 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
         import itertools
         import pyarrow as pa
         from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkFuncMode,
+            TransformWithStateInPandasFuncMode,
         )
 
         def generate_data_batches(batches):
@@ -1592,8 +1592,8 @@ class TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySp
                 ret_tuple = (chained_input_values_without_none,
                              chained_init_state_values_without_none)
-                yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, ret_tuple)
+                yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, ret_tuple)
 
-        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
 
-        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
diff --git a/python/pyspark/sql/streaming/stateful_processor_util.py b/python/pyspark/sql/streaming/stateful_processor_util.py
index c0ff176eb9c9..d13d32403ffe 100644
--- a/python/pyspark/sql/streaming/stateful_processor_util.py
+++ b/python/pyspark/sql/streaming/stateful_processor_util.py
@@ -38,10 +38,12 @@ if TYPE_CHECKING:
 # contains public APIs.
 
 
-class TransformWithStateInPySparkFuncMode(Enum):
+class TransformWithStateInPandasFuncMode(Enum):
     """
    Internal mode for python worker UDF mode for transformWithState in PySpark; external mode
    are in `StatefulProcessorHandleState` for public use purposes.
+
+    NOTE: The class has `Pandas` in its name for compatibility purposes in Spark Connect.
     """
 
     PROCESS_DATA = 1
@@ -50,10 +52,12 @@
     PRE_INIT = 4
 
 
-class TransformWithStateInPySparkUdfUtils:
+class TransformWithStateInPandasUdfUtils:
     """
    Internal Utility class used for python worker UDF for transformWithState in PySpark. This
    class is shared for both classic and spark connect mode.
+
+    NOTE: The class has `Pandas` in its name for compatibility purposes in Spark Connect.
     """
 
     def __init__(self, stateful_processor: StatefulProcessor, time_mode: str):
@@ -63,11 +67,11 @@ class TransformWithStateInPySparkUdfUtils:
     def transformWithStateUDF(
         self,
         stateful_processor_api_client: StatefulProcessorApiClient,
-        mode: TransformWithStateInPySparkFuncMode,
+        mode: TransformWithStateInPandasFuncMode,
         key: Any,
         input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
     ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
-        if mode == TransformWithStateInPySparkFuncMode.PRE_INIT:
+        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
             return self._handle_pre_init(stateful_processor_api_client)
 
         handle = StatefulProcessorHandle(stateful_processor_api_client)
@@ -76,13 +80,13 @@
             self._stateful_processor.init(handle)
             stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
 
-        if mode == TransformWithStateInPySparkFuncMode.PROCESS_TIMER:
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
             stateful_processor_api_client.set_handle_state(
                 StatefulProcessorHandleState.DATA_PROCESSED
             )
             result = self._handle_expired_timers(stateful_processor_api_client)
             return result
-        elif mode == TransformWithStateInPySparkFuncMode.COMPLETE:
+        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
             stateful_processor_api_client.set_handle_state(
                 StatefulProcessorHandleState.TIMER_PROCESSED
             )
@@ -91,14 +95,14 @@
             stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
             return iter([])
         else:
-            # mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA
+            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
             result = self._handle_data_rows(stateful_processor_api_client, key, input_rows)
             return result
 
     def transformWithStateWithInitStateUDF(
         self,
         stateful_processor_api_client: StatefulProcessorApiClient,
-        mode: TransformWithStateInPySparkFuncMode,
+        mode: TransformWithStateInPandasFuncMode,
         key: Any,
         input_rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
         initial_states: Optional[Union[Iterator["PandasDataFrameLike"], Iterator[Row]]] = None,
@@ -115,7 +119,7 @@
         - `initialStates` is None, while `inputRows` is not empty. This is not first batch.
           `initialStates` is initialized to the positional value as None.
         """
-        if mode == TransformWithStateInPySparkFuncMode.PRE_INIT:
+        if mode == TransformWithStateInPandasFuncMode.PRE_INIT:
             return self._handle_pre_init(stateful_processor_api_client)
 
         handle = StatefulProcessorHandle(stateful_processor_api_client)
@@ -124,19 +128,19 @@
             self._stateful_processor.init(handle)
             stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.INITIALIZED)
 
-        if mode == TransformWithStateInPySparkFuncMode.PROCESS_TIMER:
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
             stateful_processor_api_client.set_handle_state(
                 StatefulProcessorHandleState.DATA_PROCESSED
             )
             result = self._handle_expired_timers(stateful_processor_api_client)
             return result
-        elif mode == TransformWithStateInPySparkFuncMode.COMPLETE:
+        elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
             stateful_processor_api_client.remove_implicit_key()
             self._stateful_processor.close()
             stateful_processor_api_client.set_handle_state(StatefulProcessorHandleState.CLOSED)
             return iter([])
         else:
-            # mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA
+            # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
             batch_timestamp, watermark_timestamp = stateful_processor_api_client.get_timestamps(
                 self._time_mode
             )
diff --git a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
index 8d9bed7e6187..3fe7f68a99e5 100644
--- a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
+++ b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
@@ -31,7 +31,7 @@ from pyspark.util import handle_worker_exception
 from typing import IO
 from pyspark.worker_util import check_python_version
 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPySparkFuncMode
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
 from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
@@ -51,7 +51,7 @@ def main(infile: IO, outfile: IO) -> None:
 
     def process(
         processor: StatefulProcessorApiClient,
-        mode: TransformWithStateInPySparkFuncMode,
+        mode: TransformWithStateInPandasFuncMode,
         key: Any,
         input: Iterator["PandasDataFrameLike"],
     ) -> None:
@@ -83,7 +83,7 @@ def main(infile: IO, outfile: IO) -> None:
         stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
         process(
             stateful_processor_api_client,
-            TransformWithStateInPySparkFuncMode.PRE_INIT,
+            TransformWithStateInPandasFuncMode.PRE_INIT,
             None,
             iter([]),
         )
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 15b4641ce838..18da7bcc7704 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -34,7 +34,7 @@ from pyspark.accumulators import (
     _deserialize_accumulator,
 )
 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPySparkFuncMode
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.resource import ResourceInformation
 from pyspark.util import PythonEvalType, local_connect_and_auth
@@ -589,7 +589,7 @@ def wrap_grouped_transform_with_state_udf(f, return_type, runner_conf):
 
 def wrap_grouped_transform_with_state_init_state_udf(f, return_type, runner_conf):
     def wrapped(stateful_processor_api_client, mode, key, values):
-        if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+        if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
             values_gen = values[0]
             init_states_gen = values[1]
         else:
@@ -2001,7 +2001,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         def mapper(a):
             mode = a[0]
-            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
                 key = a[1]
 
                 def values_gen():
@@ -2038,7 +2038,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         def mapper(a):
             mode = a[0]
-            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
                 key = a[1]
 
                 def values_gen():
@@ -2070,7 +2070,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         def mapper(a):
             mode = a[0]
-            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
                 key = a[1]
                 values = a[2]
 
@@ -2103,7 +2103,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         def mapper(a):
             mode = a[0]
-            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
                 key = a[1]
                 values = a[2]
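For context on the Enum-inheritance constraint mentioned in the commit message: Python refuses to subclass an Enum once it defines members, so the 4.1 code could not keep the new `TransformWithStateInPySpark*` names and simply re-introduce the old `TransformWithStateInPandas*` names as inheriting aliases for older Spark Connect clients. The snippet below is a minimal, standalone sketch of that language behavior; it reuses the member names from the enum in the diff purely for illustration and is not Spark code.

```python
from enum import Enum


class TransformWithStateInPandasFuncMode(Enum):
    # Same members as the enum restored in stateful_processor_util.py.
    PROCESS_DATA = 1
    PROCESS_TIMER = 2
    COMPLETE = 3
    PRE_INIT = 4


try:
    # An Enum that already defines members cannot be extended, so an
    # inheritance-based alias of the renamed enum is rejected at class
    # definition time with a TypeError.
    class TransformWithStateInPySparkFuncMode(TransformWithStateInPandasFuncMode):
        pass
except TypeError as exc:
    print(f"cannot alias via inheritance: {exc}")
```

A plain (non-Enum) class such as `TransformWithStateInPandasUdfUtils` has no such restriction, but, as the commit message notes, supporting two Enum classes side by side would have required broader code changes, hence the straight revert of both renames.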