This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dad1369df70d [SPARK-52195][PYTHON][SS] Fix initial state column dropping issue for Python TWS
dad1369df70d is described below

commit dad1369df70d7b1e27610fada1f76d6455549c71
Author: bogao007 <bo....@databricks.com>
AuthorDate: Wed May 21 10:54:03 2025 +0900

    [SPARK-52195][PYTHON][SS] Fix initial state column dropping issue for Python TWS
    
    ### What changes were proposed in this pull request?
    
    Fix an initial state column dropping issue for Python TWS. This may occur when a user adds extra transformations after the `TransformWithStateInPandas` operator; during optimization the initial state columns then get pruned.
    
    ### Why are the changes needed?
    
    Without this fix, users cannot use an initial state with `TransformWithStateInPandas` if they also need extra transformations on top of it.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added unit test case.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50926 from bogao007/tws-column-dropping-fix.
    
    Authored-by: bogao007 <bo....@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../pandas/test_pandas_transform_with_state.py     | 34 ++++++++++++++++++++++
 .../plans/logical/pythonLogicalOperators.scala     | 14 ++++++++-
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index e36ae3a86a28..007ed5de2fbd 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -760,6 +760,7 @@ class TransformWithStateTestsMixin:
         time_mode="None",
         checkpoint_path=None,
         initial_state=None,
+        with_extra_transformation=False,
     ):
         input_path = tempfile.mkdtemp()
         if checkpoint_path is None:
@@ -798,6 +799,14 @@ class TransformWithStateTestsMixin:
                 initialState=initial_state,
             )
 
+        if with_extra_transformation:
+            from pyspark.sql import functions as fn
+
+            tws_df = tws_df.select(
+                fn.col("id").cast("string").alias("key"),
+                fn.to_json(fn.struct(fn.col("value"))).alias("value"),
+            )
+
         q = (
             tws_df.writeStream.queryName("this_query")
             .option("checkpointLocation", checkpoint_path)
@@ -835,6 +844,31 @@ class TransformWithStateTestsMixin:
             SimpleStatefulProcessorWithInitialStateFactory(), check_results
         )
 
+    def test_transform_with_state_init_state_with_extra_transformation(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                # for key 0, initial state was processed and it was only processed once;
+                # for key 1, it did not appear in the initial state df;
+                # for key 3, it did not appear in the first batch of input keys
+                # so it won't be emitted
+                assert set(batch_df.sort("key").collect()) == {
+                    Row(key="0", value=f'{{"value":"{789 + 123 + 46}"}}'),
+                    Row(key="1", value=f'{{"value":"{146 + 346}"}}'),
+                }
+            else:
+                # for key 0, verify initial state was only processed once in the first batch;
+                # for key 3, verify init state was processed and reflected in the accumulated value
+                assert set(batch_df.sort("key").collect()) == {
+                    Row(key="0", value=f'{{"value":"{789 + 123 + 46 + 67}"}}'),
+                    Row(key="3", value=f'{{"value":"{987 + 12}"}}'),
+                }
+
+        self._test_transform_with_state_init_state(
+            SimpleStatefulProcessorWithInitialStateFactory(),
+            check_results,
+            with_extra_transformation=True,
+        )
+
     def _test_transform_with_state_non_contiguous_grouping_cols(
         self, stateful_processor_factory, check_results, initial_state=None
     ):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 8de801a8ffa1..0c0ea2443489 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -204,7 +204,7 @@ case class TransformWithStateInPySpark(
   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
 
   override lazy val references: AttributeSet =
-    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+    AttributeSet(leftAttributes ++ rightReferences ++ functionExpr.references) -- producedAttributes
 
   override protected def withNewChildrenInternal(
      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPySpark =
@@ -225,6 +225,18 @@ case class TransformWithStateInPySpark(
       left.output.take(groupingAttributesLen)
     }
   }
+
+  // Include the initial state columns in the references to avoid being column pruned.
+  private def rightReferences: Seq[Attribute] = {
+    assert(resolved, "This method is expected to be called after resolution.")
+    if (hasInitialState) {
+      right.output
+    } else {
+      // Dummy variables for passing the distribution & ordering check
+      // in physical operators.
+      left.output.take(groupingAttributesLen)
+    }
+  }
 }
 
 object TransformWithStateInPySpark {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
