This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6f47783d3512 [SPARK-51827][SS][CONNECT] Support Spark Connect on transformWithState in PySpark
6f47783d3512 is described below

commit 6f47783d3512c423d8f1b189d0b701b0e4f13415
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Mon Apr 28 08:34:11 2025 +0900

    [SPARK-51827][SS][CONNECT] Support Spark Connect on transformWithState in PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to support Spark Connect on transformWithState in PySpark. The code is mostly shared between the Pandas version and the Row version.

    We rely on PythonEvalType to determine the user-facing type of the API, hence no proto change is needed.
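
    As a rough sketch of the dispatch (hedged: `pick_eval_type` is a hypothetical
    helper, and the import path of `PythonEvalType` may vary across Spark
    versions), the eval type alone is enough for the server to pick the variant:

    ```python
    from pyspark.util import PythonEvalType

    def pick_eval_type(has_initial_state: bool) -> int:
        # The row-based transformWithState maps to one of two eval types; the
        # *_INIT_STATE_* variant is used only when an initial state DataFrame
        # is supplied, mirroring the branch in connect/group.py below.
        if has_initial_state:
            return PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF
        return PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
    ```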
    
    ### Why are the changes needed?
    
    The new API needs to be supported with Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the new transformWithState API becomes available in Spark Connect.
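
    As a hedged illustration of the call shape (CountProcessor, the state schema,
    and the column names are hypothetical, not part of this commit; only the
    transformWithState signature follows connect/group.py in the diff below):

    ```python
    from typing import Any, Iterator
    from pyspark.sql import Row
    from pyspark.sql.streaming.stateful_processor import (
        StatefulProcessor,
        StatefulProcessorHandle,
    )

    class CountProcessor(StatefulProcessor):
        """Hypothetical processor keeping a running count per grouping key."""

        def init(self, handle: StatefulProcessorHandle) -> None:
            # Value state with a single LONG column holding the count so far.
            self._count = handle.getValueState("count", "count LONG")

        def handleInputRows(
            self, key: Any, rows: Iterator[Row], timer_values: Any
        ) -> Iterator[Row]:
            prev = self._count.get()[0] if self._count.exists() else 0
            total = prev + sum(1 for _ in rows)
            self._count.update((total,))
            yield Row(id=key[0], count=total)

        def close(self) -> None:
            pass

    # Assuming df is a streaming DataFrame with an "id" column, the same call
    # now works when df comes from a Spark Connect session:
    result = df.groupBy("id").transformWithState(
        statefulProcessor=CountProcessor(),
        outputStructType="id STRING, count LONG",
        outputMode="Update",
        timeMode="None",
    )
    ```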
    
    ### How was this patch tested?
    
    New test suites.
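
    The parity suite follows the usual pattern (ExampleParityTests below is a
    hypothetical name; the committed suite, including the SparkConf merging
    across both parent classes, is in the test diff): every case from the
    classic-mode mixin is rerun under a Spark Connect session.

    ```python
    import unittest

    from pyspark.sql.tests.pandas.test_pandas_transform_with_state import (
        TransformWithStateInPySparkTestsMixin,
    )
    from pyspark.testing.connectutils import ReusedConnectTestCase

    class ExampleParityTests(
        TransformWithStateInPySparkTestsMixin, ReusedConnectTestCase
    ):
        # All test methods come from the mixin; the connect base class provides
        # a remote Spark session, so the same cases exercise Spark Connect.
        pass

    if __name__ == "__main__":
        unittest.main()
    ```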
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50704 from HeartSaVioR/WIP-transform-with-state-python-in-spark-connect.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/connect/group.py                | 52 +++++++++++++++
 python/pyspark/sql/connect/plan.py                 | 21 ++++--
 .../test_parity_pandas_transform_with_state.py     | 32 +++++++++-
 .../sql/tests/test_connect_compatibility.py        |  3 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 74 ++++++++++++++++------
 5 files changed, 153 insertions(+), 29 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 5a4888fda6db..ef0384cf8252 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -414,6 +414,58 @@ class GroupedData:
 
     transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__
 
+    def transformWithState(
+        self,
+        statefulProcessor: StatefulProcessor,
+        outputStructType: Union[StructType, str],
+        outputMode: str,
+        timeMode: str,
+        initialState: Optional["GroupedData"] = None,
+        eventTimeColumnName: str = "",
+    ) -> "DataFrame":
+        from pyspark.sql.connect.udf import UserDefinedFunction
+        from pyspark.sql.connect.dataframe import DataFrame
+        from pyspark.sql.streaming.stateful_processor_util import (
+            TransformWithStateInPySparkUdfUtils,
+        )
+
+        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
+        if initialState is None:
+            udf_obj = UserDefinedFunction(
+                udf_util.transformWithStateUDF,
+                returnType=outputStructType,
+                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
+            )
+            initial_state_plan = None
+            initial_state_grouping_cols = None
+        else:
+            self._df._check_same_session(initialState._df)
+            udf_obj = UserDefinedFunction(
+                udf_util.transformWithStateWithInitStateUDF,
+                returnType=outputStructType,
+                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
+            )
+            initial_state_plan = initialState._df._plan
+            initial_state_grouping_cols = initialState._grouping_cols
+
+        return DataFrame(
+            plan.TransformWithStateInPySpark(
+                child=self._df._plan,
+                grouping_cols=self._grouping_cols,
+                function=udf_obj,
+                output_schema=outputStructType,
+                output_mode=outputMode,
+                time_mode=timeMode,
+                event_time_col_name=eventTimeColumnName,
+                cols=self._df.columns,
+                initial_state_plan=initial_state_plan,
+                initial_state_grouping_cols=initial_state_grouping_cols,
+            ),
+            session=self._df._session,
+        )
+
+    transformWithState.__doc__ = PySparkGroupedData.transformWithState.__doc__
+
     def applyInArrow(
         self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
     ) -> "DataFrame":
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index c4c7a6a63630..c5b6f5430d6d 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -2546,8 +2546,8 @@ class ApplyInPandasWithState(LogicalPlan):
         return self._with_relations(plan, session)
 
 
-class TransformWithStateInPandas(LogicalPlan):
-    """Logical plan object for a TransformWithStateInPandas."""
+class BaseTransformWithStateInPySpark(LogicalPlan):
+    """Base implementation of logical plan object for a 
TransformWithStateIn(PySpark/Pandas)."""
 
     def __init__(
         self,
@@ -2600,7 +2600,7 @@ class TransformWithStateInPandas(LogicalPlan):
                 [c.to_plan(session) for c in self._initial_state_grouping_cols]
             )
 
-        # fill in transformWithStateInPandas related fields
+        # fill in transformWithStateInPySpark/Pandas related fields
         tws_info = proto.TransformWithStateInfo()
         tws_info.time_mode = self._time_mode
         tws_info.event_time_column_name = self._event_time_col_name
@@ -2608,12 +2608,25 @@ class TransformWithStateInPandas(LogicalPlan):
 
         plan.group_map.transform_with_state_info.CopyFrom(tws_info)
 
-        # wrap transformWithStateInPandasUdf in a function
+        # wrap transformWithStateInPySparkUdf in a function
         plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))
 
         return self._with_relations(plan, session)
 
 
+class TransformWithStateInPySpark(BaseTransformWithStateInPySpark):
+    """Logical plan object for a TransformWithStateInPySpark."""
+
+    pass
+
+
+# Retaining this to avoid breaking backward compatibility.
+class TransformWithStateInPandas(BaseTransformWithStateInPySpark):
+    """Logical plan object for a TransformWithStateInPandas."""
+
+    pass
+
+
 class PythonUDTF:
     """Represents a Python user-defined table function."""
 
diff --git a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
index 26f2941d3d1f..e772c2139326 100644
--- a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
@@ -18,6 +18,7 @@ import unittest
 
 from pyspark.sql.tests.pandas.test_pandas_transform_with_state import (
     TransformWithStateInPandasTestsMixin,
+    TransformWithStateInPySparkTestsMixin,
 )
 from pyspark import SparkConf
 from pyspark.testing.connectutils import ReusedConnectTestCase
@@ -53,8 +54,35 @@ class TransformWithStateInPandasParityTests(
         pass
 
 
-# TODO(SPARK-51827): Need to copy the parity test when we implement transformWithState in
-#  Python Spark Connect
+class TransformWithStateInPySparkParityTests(
+    TransformWithStateInPySparkTestsMixin, ReusedConnectTestCase
+):
+    """
+    Spark Connect parity tests for TransformWithStateInPySpark. Run every test case in
+     `TransformWithStateInPySparkTestsMixin` in Spark Connect mode.
+    """
+
+    @classmethod
+    def conf(cls):
+        # Due to multiple inheritance from the same level, we need to explicitly set configs from
+        # both TransformWithStateInPySparkTestsMixin and ReusedConnectTestCase here
+        cfg = SparkConf(loadDefaults=False)
+        for base in cls.__bases__:
+            if hasattr(base, "conf"):
+                parent_cfg = base.conf()
+                for k, v in parent_cfg.getAll():
+                    cfg.set(k, v)
+
+        # Remove extra config for connect suites
+        if cfg._jconf is not None:
+            cfg._jconf.remove("spark.master")
+
+        return cfg
+
+    @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.")
+    def test_schema_evolution_scenarios(self):
+        pass
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state import *  # noqa: F401,E501
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index 7323dc9424de..b2e0cc6229c4 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -395,8 +395,7 @@ class ConnectCompatibilityTestsMixin:
         """Test Grouping compatibility between classic and connect."""
         expected_missing_connect_properties = set()
         expected_missing_classic_properties = set()
-        # TODO(SPARK-51827): Add missing method `transformWithState` to the connect version
-        expected_missing_connect_methods = {"transformWithState"}
+        expected_missing_connect_methods = set()
         expected_missing_classic_methods = set()
         self.check_compatibility(
             ClassicGroupedData,
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 911d79ecdb12..849dd9532405 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -661,7 +661,11 @@ class SparkConnectPlanner(
 
           case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF |
               PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF =>
-            transformTransformWithStateInPandas(pythonUdf, group, rel)
+            transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = true)
+
+          case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF |
+              
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
+            transformTransformWithStateInPySpark(pythonUdf, group, rel, 
usePandas = false)
 
           case _ =>
             throw InvalidPlanInput(
@@ -1102,10 +1106,11 @@ class SparkConnectPlanner(
       .logicalPlan
   }
 
-  private def transformTransformWithStateInPandas(
+  private def transformTransformWithStateInPySpark(
       pythonUdf: PythonUDF,
       groupedDs: RelationalGroupedDataset,
-      rel: proto.GroupMap): LogicalPlan = {
+      rel: proto.GroupMap,
+      usePandas: Boolean): LogicalPlan = {
     val twsInfo = rel.getTransformWithStateInfo
     val outputSchema: StructType = {
       transformDataType(twsInfo.getOutputSchema) match {
@@ -1131,25 +1136,52 @@ class SparkConnectPlanner(
         .builder(groupedDs.df.logicalPlan.output)
         .asInstanceOf[PythonUDF]
 
-      groupedDs
-        .transformWithStateInPandas(
-          Column(resolvedPythonUDF),
-          outputSchema,
-          rel.getOutputMode,
-          twsInfo.getTimeMode,
-          initialStateDs,
-          twsInfo.getEventTimeColumnName)
-        .logicalPlan
+      if (usePandas) {
+        groupedDs
+          .transformWithStateInPandas(
+            Column(resolvedPythonUDF),
+            outputSchema,
+            rel.getOutputMode,
+            twsInfo.getTimeMode,
+            initialStateDs,
+            twsInfo.getEventTimeColumnName)
+          .logicalPlan
+      } else {
+        // use Row
+        groupedDs
+          .transformWithStateInPySpark(
+            Column(resolvedPythonUDF),
+            outputSchema,
+            rel.getOutputMode,
+            twsInfo.getTimeMode,
+            initialStateDs,
+            twsInfo.getEventTimeColumnName)
+          .logicalPlan
+      }
+
     } else {
-      groupedDs
-        .transformWithStateInPandas(
-          Column(pythonUdf),
-          outputSchema,
-          rel.getOutputMode,
-          twsInfo.getTimeMode,
-          null,
-          twsInfo.getEventTimeColumnName)
-        .logicalPlan
+      if (usePandas) {
+        groupedDs
+          .transformWithStateInPandas(
+            Column(pythonUdf),
+            outputSchema,
+            rel.getOutputMode,
+            twsInfo.getTimeMode,
+            null,
+            twsInfo.getEventTimeColumnName)
+          .logicalPlan
+      } else {
+        // use Row
+        groupedDs
+          .transformWithStateInPySpark(
+            Column(pythonUdf),
+            outputSchema,
+            rel.getOutputMode,
+            twsInfo.getTimeMode,
+            null,
+            twsInfo.getEventTimeColumnName)
+          .logicalPlan
+      }
     }
   }
 

