This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new dad1369df70d [SPARK-52195][PYTHON][SS] Fix initial state column dropping issue for Python TWS

dad1369df70d is described below

commit dad1369df70d7b1e27610fada1f76d6455549c71
Author: bogao007 <bo....@databricks.com>
AuthorDate: Wed May 21 10:54:03 2025 +0900

[SPARK-52195][PYTHON][SS] Fix initial state column dropping issue for Python TWS

### What changes were proposed in this pull request?

Fix the initial state column dropping issue for Python TWS. The issue may occur when a user adds extra transformations after the `TransformWithStateInPandas` operator; in that case the initial state columns get pruned during optimization.
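For illustration, here is a minimal sketch of the affected pattern, modeled on the unit test added in this patch; `df`, `init_df`, `processor`, and `output_schema` are hypothetical placeholders, not names from the patch itself:

```python
from pyspark.sql import functions as fn

# Stateful query with an initial state. `df`, `init_df`, `processor`, and
# `output_schema` are assumed placeholders for this sketch.
tws_df = df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=processor,
    outputStructType=output_schema,
    outputMode="Update",
    timeMode="None",
    initialState=init_df.groupBy("id"),
)

# Extra transformation on top of the TWS operator. Before this fix, the
# optimizer could prune the initial state columns that the operator still
# needs, because they were not part of the plan node's references.
result_df = tws_df.select(
    fn.col("id").cast("string").alias("key"),
    fn.to_json(fn.struct(fn.col("value"))).alias("value"),
)
```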
### Why are the changes needed?

Without this fix, users cannot use an initial state with `TransformWithStateInPandas` if they require extra transformations on its output.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a unit test case.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50926 from bogao007/tws-column-dropping-fix.

Authored-by: bogao007 <bo....@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../pandas/test_pandas_transform_with_state.py     | 34 ++++++++++++++++++++++
 .../plans/logical/pythonLogicalOperators.scala     | 14 ++++++++-
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index e36ae3a86a28..007ed5de2fbd 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -760,6 +760,7 @@ class TransformWithStateTestsMixin:
         time_mode="None",
         checkpoint_path=None,
         initial_state=None,
+        with_extra_transformation=False,
     ):
         input_path = tempfile.mkdtemp()
         if checkpoint_path is None:
@@ -798,6 +799,14 @@ class TransformWithStateTestsMixin:
             initialState=initial_state,
         )
 
+        if with_extra_transformation:
+            from pyspark.sql import functions as fn
+
+            tws_df = tws_df.select(
+                fn.col("id").cast("string").alias("key"),
+                fn.to_json(fn.struct(fn.col("value"))).alias("value"),
+            )
+
         q = (
             tws_df.writeStream.queryName("this_query")
             .option("checkpointLocation", checkpoint_path)
@@ -835,6 +844,31 @@ class TransformWithStateTestsMixin:
             SimpleStatefulProcessorWithInitialStateFactory(), check_results
         )
 
+    def test_transform_with_state_init_state_with_extra_transformation(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                # for key 0, initial state was processed and it was only processed once;
+                # for key 1, it did not appear in the initial state df;
+                # for key 3, it did not appear in the first batch of input keys
+                # so it won't be emitted
+                assert set(batch_df.sort("key").collect()) == {
+                    Row(key="0", value=f'{{"value":"{789 + 123 + 46}"}}'),
+                    Row(key="1", value=f'{{"value":"{146 + 346}"}}'),
+                }
+            else:
+                # for key 0, verify initial state was only processed once in the first batch;
+                # for key 3, verify init state was processed and reflected in the accumulated value
+                assert set(batch_df.sort("key").collect()) == {
+                    Row(key="0", value=f'{{"value":"{789 + 123 + 46 + 67}"}}'),
+                    Row(key="3", value=f'{{"value":"{987 + 12}"}}'),
+                }
+
+        self._test_transform_with_state_init_state(
+            SimpleStatefulProcessorWithInitialStateFactory(),
+            check_results,
+            with_extra_transformation=True,
+        )
+
     def _test_transform_with_state_non_contiguous_grouping_cols(
         self, stateful_processor_factory, check_results, initial_state=None
     ):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 8de801a8ffa1..0c0ea2443489 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -204,7 +204,7 @@ case class TransformWithStateInPySpark(
   override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
 
   override lazy val references: AttributeSet =
-    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+    AttributeSet(leftAttributes ++ rightReferences ++ functionExpr.references) -- producedAttributes
 
   override protected def withNewChildrenInternal(
       newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPySpark =
@@ -225,6 +225,18 @@ case class TransformWithStateInPySpark(
       left.output.take(groupingAttributesLen)
     }
   }
+
+  // Include the initial state columns in the references to avoid being column pruned.
+  private def rightReferences: Seq[Attribute] = {
+    assert(resolved, "This method is expected to be called after resolution.")
+    if (hasInitialState) {
+      right.output
+    } else {
+      // Dummy variables for passing the distribution & ordering check
+      // in physical operators.
+      left.output.take(groupingAttributesLen)
+    }
+  }
 }
 
 object TransformWithStateInPySpark {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org