This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 6f47783d3512 [SPARK-51827][SS][CONNECT] Support Spark Connect on transformWithState in PySpark
6f47783d3512 is described below

commit 6f47783d3512c423d8f1b189d0b701b0e4f13415
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Mon Apr 28 08:34:11 2025 +0900

    [SPARK-51827][SS][CONNECT] Support Spark Connect on transformWithState in PySpark

    ### What changes were proposed in this pull request?

    This PR proposes to support Spark Connect on transformWithState in PySpark. The code is mostly reused between the Pandas version and the Row version. We rely on PythonEvalType to determine the user-facing type of the API, hence no proto change is needed.

    ### Why are the changes needed?

    The new API needs to be supported with Spark Connect.

    ### Does this PR introduce _any_ user-facing change?

    Yes, this exposes a new API in Spark Connect.

    ### How was this patch tested?

    New test suites.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #50704 from HeartSaVioR/WIP-transform-with-state-python-in-spark-connect.

    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/connect/group.py                | 52 ++++++++++++++++
 python/pyspark/sql/connect/plan.py                 | 21 ++++--
 .../test_parity_pandas_transform_with_state.py     | 32 +++++++++-
 .../sql/tests/test_connect_compatibility.py        |  3 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 74 ++++++++++++++------
 5 files changed, 153 insertions(+), 29 deletions(-)
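For context, a minimal sketch of how the new API is called from a Connect client, based on the `transformWithState` signature added to `group.py` in the diff below. The `StatefulProcessor` lifecycle and value-state accessors are assumed to mirror the existing `transformWithStateInPandas` API (operating on plain `Row`s rather than pandas DataFrames); `CountProcessor`, the grouping expression, and the Connect endpoint URL are illustrative placeholders, not part of the commit:

```python
from typing import Iterator

from pyspark.sql import Row, SparkSession
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle


class CountProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # One running count per grouping key, kept in a single-column value state.
        self.count_state = handle.getValueState("countState", "count LONG")

    def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
        # `key` is the grouping-key tuple; `rows` iterates the input rows for it.
        existing = self.count_state.get()[0] if self.count_state.exists() else 0
        updated = existing + sum(1 for _ in rows)
        self.count_state.update((updated,))
        yield Row(id=key[0], count=updated)

    def close(self) -> None:
        pass


# "sc://localhost" is a placeholder Connect endpoint.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
events = spark.readStream.format("rate").load().selectExpr("value % 10 AS id")

counts = events.groupBy("id").transformWithState(
    statefulProcessor=CountProcessor(),
    outputStructType="id LONG, count LONG",
    outputMode="Update",
    timeMode="None",
)
query = counts.writeStream.format("console").outputMode("update").start()
```

Because the Row and Pandas variants share the same plan and proto shape, the only client-side difference is the eval type attached to the UDF, as the `group.py` diff below shows.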
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 5a4888fda6db..ef0384cf8252 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -414,6 +414,58 @@ class GroupedData:
 
     transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__
 
+    def transformWithState(
+        self,
+        statefulProcessor: StatefulProcessor,
+        outputStructType: Union[StructType, str],
+        outputMode: str,
+        timeMode: str,
+        initialState: Optional["GroupedData"] = None,
+        eventTimeColumnName: str = "",
+    ) -> "DataFrame":
+        from pyspark.sql.connect.udf import UserDefinedFunction
+        from pyspark.sql.connect.dataframe import DataFrame
+        from pyspark.sql.streaming.stateful_processor_util import (
+            TransformWithStateInPySparkUdfUtils,
+        )
+
+        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+        if initialState is None:
+            udf_obj = UserDefinedFunction(
+                udf_util.transformWithStateUDF,
+                returnType=outputStructType,
+                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+            )
+            initial_state_plan = None
+            initial_state_grouping_cols = None
+        else:
+            self._df._check_same_session(initialState._df)
+            udf_obj = UserDefinedFunction(
+                udf_util.transformWithStateWithInitStateUDF,
+                returnType=outputStructType,
+                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
+            )
+            initial_state_plan = initialState._df._plan
+            initial_state_grouping_cols = initialState._grouping_cols
+
+        return DataFrame(
+            plan.TransformWithStateInPySpark(
+                child=self._df._plan,
+                grouping_cols=self._grouping_cols,
+                function=udf_obj,
+                output_schema=outputStructType,
+                output_mode=outputMode,
+                time_mode=timeMode,
+                event_time_col_name=eventTimeColumnName,
+                cols=self._df.columns,
+                initial_state_plan=initial_state_plan,
+                initial_state_grouping_cols=initial_state_grouping_cols,
+            ),
+            session=self._df._session,
+        )
+
+    transformWithState.__doc__ = PySparkGroupedData.transformWithState.__doc__
+
     def applyInArrow(
         self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
     ) -> "DataFrame":
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index c4c7a6a63630..c5b6f5430d6d 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -2546,8 +2546,8 @@ class ApplyInPandasWithState(LogicalPlan):
         return self._with_relations(plan, session)
 
 
-class TransformWithStateInPandas(LogicalPlan):
-    """Logical plan object for a TransformWithStateInPandas."""
+class BaseTransformWithStateInPySpark(LogicalPlan):
+    """Base implementation of logical plan object for a TransformWithStateIn(PySpark/Pandas)."""
 
     def __init__(
         self,
@@ -2600,7 +2600,7 @@ class TransformWithStateInPandas(LogicalPlan):
                 [c.to_plan(session) for c in self._initial_state_grouping_cols]
             )
 
-        # fill in transformWithStateInPandas related fields
+        # fill in transformWithStateInPySpark/Pandas related fields
         tws_info = proto.TransformWithStateInfo()
         tws_info.time_mode = self._time_mode
         tws_info.event_time_column_name = self._event_time_col_name
@@ -2608,12 +2608,25 @@ class TransformWithStateInPandas(LogicalPlan):
 
         plan.group_map.transform_with_state_info.CopyFrom(tws_info)
 
-        # wrap transformWithStateInPandasUdf in a function
+        # wrap transformWithStateInPySparkUdf in a function
        plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))
         return self._with_relations(plan, session)
 
 
+class TransformWithStateInPySpark(BaseTransformWithStateInPySpark):
+    """Logical plan object for a TransformWithStateInPySpark."""
+
+    pass
+
+
+# Retaining this to avoid breaking backward compatibility.
+class TransformWithStateInPandas(BaseTransformWithStateInPySpark):
+    """Logical plan object for a TransformWithStateInPandas."""
+
+    pass
+
+
 class PythonUDTF:
     """Represents a Python user-defined table function."""
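A note on the `plan.py` refactor above: the serialization logic moves to a shared base class, and both concrete plan classes are trivial subclasses, so existing code that references the old Pandas plan name keeps working. An illustrative check (not part of the commit), assuming the class names from the diff:

```python
# Illustrative only: both plan classes inherit the shared serializer, so the
# GroupMap proto they emit is identical; the Row-vs-Pandas distinction lives
# in the UDF eval type, not in the plan class.
from pyspark.sql.connect.plan import (
    BaseTransformWithStateInPySpark,
    TransformWithStateInPandas,
    TransformWithStateInPySpark,
)

assert issubclass(TransformWithStateInPandas, BaseTransformWithStateInPySpark)
assert issubclass(TransformWithStateInPySpark, BaseTransformWithStateInPySpark)
```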
+ """ + + @classmethod + def conf(cls): + # Due to multiple inheritance from the same level, we need to explicitly setting configs in + # both TransformWithStateInPySparkTestsMixin and ReusedConnectTestCase here + cfg = SparkConf(loadDefaults=False) + for base in cls.__bases__: + if hasattr(base, "conf"): + parent_cfg = base.conf() + for k, v in parent_cfg.getAll(): + cfg.set(k, v) + + # Extra removing config for connect suites + if cfg._jconf is not None: + cfg._jconf.remove("spark.master") + + return cfg + + @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.") + def test_schema_evolution_scenarios(self): + pass + if __name__ == "__main__": from pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state import * # noqa: F401,E501 diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index 7323dc9424de..b2e0cc6229c4 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -395,8 +395,7 @@ class ConnectCompatibilityTestsMixin: """Test Grouping compatibility between classic and connect.""" expected_missing_connect_properties = set() expected_missing_classic_properties = set() - # TODO(SPARK-51827): Add missing method `transformWithState` to the connect version - expected_missing_connect_methods = {"transformWithState"} + expected_missing_connect_methods = set() expected_missing_classic_methods = set() self.check_compatibility( ClassicGroupedData, diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 911d79ecdb12..849dd9532405 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -661,7 +661,11 @@ class SparkConnectPlanner( case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF | PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF => - transformTransformWithStateInPandas(pythonUdf, group, rel) + transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = true) + + case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF | + PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF => + transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = false) case _ => throw InvalidPlanInput( @@ -1102,10 +1106,11 @@ class SparkConnectPlanner( .logicalPlan } - private def transformTransformWithStateInPandas( + private def transformTransformWithStateInPySpark( pythonUdf: PythonUDF, groupedDs: RelationalGroupedDataset, - rel: proto.GroupMap): LogicalPlan = { + rel: proto.GroupMap, + usePandas: Boolean): LogicalPlan = { val twsInfo = rel.getTransformWithStateInfo val outputSchema: StructType = { transformDataType(twsInfo.getOutputSchema) match { @@ -1131,25 +1136,52 @@ class SparkConnectPlanner( .builder(groupedDs.df.logicalPlan.output) .asInstanceOf[PythonUDF] - groupedDs - .transformWithStateInPandas( - Column(resolvedPythonUDF), - outputSchema, - rel.getOutputMode, - twsInfo.getTimeMode, - initialStateDs, - twsInfo.getEventTimeColumnName) - .logicalPlan + if (usePandas) { + groupedDs + .transformWithStateInPandas( + Column(resolvedPythonUDF), + outputSchema, + rel.getOutputMode, + twsInfo.getTimeMode, + initialStateDs, + 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 911d79ecdb12..849dd9532405 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -661,7 +661,11 @@ class SparkConnectPlanner(
 
       case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF |
           PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF =>
-        transformTransformWithStateInPandas(pythonUdf, group, rel)
+        transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = true)
+
+      case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF |
+          PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
+        transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = false)
 
       case _ =>
         throw InvalidPlanInput(
@@ -1102,10 +1106,11 @@ class SparkConnectPlanner(
       .logicalPlan
   }
 
-  private def transformTransformWithStateInPandas(
+  private def transformTransformWithStateInPySpark(
       pythonUdf: PythonUDF,
       groupedDs: RelationalGroupedDataset,
-      rel: proto.GroupMap): LogicalPlan = {
+      rel: proto.GroupMap,
+      usePandas: Boolean): LogicalPlan = {
     val twsInfo = rel.getTransformWithStateInfo
     val outputSchema: StructType = {
       transformDataType(twsInfo.getOutputSchema) match {
@@ -1131,25 +1136,52 @@ class SparkConnectPlanner(
         .builder(groupedDs.df.logicalPlan.output)
         .asInstanceOf[PythonUDF]
 
-      groupedDs
-        .transformWithStateInPandas(
-          Column(resolvedPythonUDF),
-          outputSchema,
-          rel.getOutputMode,
-          twsInfo.getTimeMode,
-          initialStateDs,
-          twsInfo.getEventTimeColumnName)
-        .logicalPlan
+      if (usePandas) {
+        groupedDs
+          .transformWithStateInPandas(
+            Column(resolvedPythonUDF),
+            outputSchema,
+            rel.getOutputMode,
+            twsInfo.getTimeMode,
+            initialStateDs,
+            twsInfo.getEventTimeColumnName)
+          .logicalPlan
+      } else {
+        // use Row
+        groupedDs
+          .transformWithStateInPySpark(
+            Column(resolvedPythonUDF),
+            outputSchema,
+            rel.getOutputMode,
+            twsInfo.getTimeMode,
+            initialStateDs,
+            twsInfo.getEventTimeColumnName)
+          .logicalPlan
+      }
     } else {
-      groupedDs
-        .transformWithStateInPandas(
-          Column(pythonUdf),
-          outputSchema,
-          rel.getOutputMode,
-          twsInfo.getTimeMode,
-          null,
-          twsInfo.getEventTimeColumnName)
-        .logicalPlan
+      if (usePandas) {
+        groupedDs
+          .transformWithStateInPandas(
+            Column(pythonUdf),
+            outputSchema,
+            rel.getOutputMode,
+            twsInfo.getTimeMode,
+            null,
+            twsInfo.getEventTimeColumnName)
+          .logicalPlan
+      } else {
+        // use Row
+        groupedDs
+          .transformWithStateInPySpark(
+            Column(pythonUdf),
+            outputSchema,
+            rel.getOutputMode,
+            twsInfo.getTimeMode,
+            null,
+            twsInfo.getEventTimeColumnName)
+          .logicalPlan
+      }
     }
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org