This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e03319fd9219 [SPARK-49676][SS][PYTHON] Add Support for Chaining of 
Operators in transformWithStateInPandas API
e03319fd9219 is described below

commit e03319fd9219da7162c12a15998d5718edc4c49e
Author: jingz-db <[email protected]>
AuthorDate: Wed Nov 27 15:27:59 2024 +0900

    [SPARK-49676][SS][PYTHON] Add Support for Chaining of Operators in 
transformWithStateInPandas API
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for defining an event time column in the output dataset of the
    `TransformWithStateInPandas` operator. The new event time column will be used to
    evaluate watermark expressions in downstream operators.
    
    ### Why are the changes needed?
    
    This change complements the Scala implementation of operator chaining.
    PR in Scala: https://github.com/apache/spark/pull/45376
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can now specify an event time column as:
    ```
    df.groupBy("id")
      .transformWithStateInPandas(
          statefulProcessor=stateful_processor,
          outputStructType=output_schema,
          outputMode="Update",
          timeMode=timeMode,
          eventTimeColumnName="outputTimestamp"
      )
    ```
    
    ### How was this patch tested?
    
    Integration tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48124 from jingz-db/python-chaining-op.
    
    Lead-authored-by: jingz-db <[email protected]>
    Co-authored-by: Jing Zhan <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 python/pyspark/sql/pandas/group_ops.py             |   2 +
 .../pandas/test_pandas_transform_with_state.py     | 158 ++++++++++++++++++---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   1 -
 .../analysis/UnsupportedOperationChecker.scala     |   1 +
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   3 +
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |   6 +-
 .../spark/sql/RelationalGroupedDataset.scala       |  30 +++-
 .../spark/sql/execution/SparkStrategies.scala      |  10 +-
 .../python/TransformWithStateInPandasExec.scala    |  45 +++++-
 .../execution/streaming/IncrementalExecution.scala |  17 +++
 .../streaming/TransformWithStateExec.scala         |   4 +-
 11 files changed, 245 insertions(+), 32 deletions(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index 56efe0676c08..d8f22e434374 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -374,6 +374,7 @@ class PandasGroupedOpsMixin:
         outputMode: str,
         timeMode: str,
         initialState: Optional["GroupedData"] = None,
+        eventTimeColumnName: str = "",
     ) -> DataFrame:
         """
         Invokes methods defined in the stateful processor used in arbitrary 
state API v2. It
@@ -662,6 +663,7 @@ class PandasGroupedOpsMixin:
             outputMode,
             timeMode,
             initial_state_java_obj,
+            eventTimeColumnName,
         )
         return DataFrame(jdf, self.session)
 
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 514339249818..f385d7cd1abc 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -27,13 +27,7 @@ from typing import cast
 from pyspark import SparkConf
 from pyspark.errors import PySparkRuntimeError
 from pyspark.sql.functions import split
-from pyspark.sql.types import (
-    StringType,
-    StructType,
-    StructField,
-    Row,
-    IntegerType,
-)
+from pyspark.sql.types import StringType, StructType, StructField, Row, 
IntegerType, TimestampType
 from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -247,11 +241,15 @@ class TransformWithStateInPandasTestsMixin:
 
     # test list state with ttl has the same behavior as list state when state 
doesn't expire.
     def test_transform_with_state_in_pandas_list_state_large_ttl(self):
-        def check_results(batch_df, _):
-            assert set(batch_df.sort("id").collect()) == {
-                Row(id="0", countAsString="2"),
-                Row(id="1", countAsString="2"),
-            }
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="0", countAsString="2"),
+                    Row(id="1", countAsString="2"),
+                }
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
 
         self._test_transform_with_state_in_pandas_basic(
             ListStateLargeTTLProcessor(), check_results, True, "processingTime"
@@ -268,11 +266,15 @@ class TransformWithStateInPandasTestsMixin:
 
     # test map state with ttl has the same behavior as map state when state 
doesn't expire.
     def test_transform_with_state_in_pandas_map_state_large_ttl(self):
-        def check_results(batch_df, _):
-            assert set(batch_df.sort("id").collect()) == {
-                Row(id="0", countAsString="2"),
-                Row(id="1", countAsString="2"),
-            }
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="0", countAsString="2"),
+                    Row(id="1", countAsString="2"),
+                }
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
 
         self._test_transform_with_state_in_pandas_basic(
             MapStateLargeTTLProcessor(), check_results, True, "processingTime"
@@ -287,11 +289,14 @@ class TransformWithStateInPandasTestsMixin:
                     Row(id="0", countAsString="2"),
                     Row(id="1", countAsString="2"),
                 }
-            else:
+            elif batch_id == 1:
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="0", countAsString="3"),
                     Row(id="1", countAsString="2"),
                 }
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
 
         self._test_transform_with_state_in_pandas_basic(
             SimpleTTLStatefulProcessor(), check_results, False, 
"processingTime"
@@ -348,6 +353,9 @@ class TransformWithStateInPandasTestsMixin:
                         Row(id="ttl-map-state-count-1", count=3),
                     ],
                 )
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
             if batch_id == 0 or batch_id == 1:
                 time.sleep(6)
 
@@ -466,7 +474,7 @@ class TransformWithStateInPandasTestsMixin:
                 ).first()["timeValues"]
                 check_timestamp(batch_df)
 
-            else:
+            elif batch_id == 2:
                 assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
                     Row(id="0", countAsString="3"),
                     Row(id="0", countAsString="-1"),
@@ -480,6 +488,10 @@ class TransformWithStateInPandasTestsMixin:
                 ).first()["timeValues"]
                 assert current_batch_expired_timestamp > 
self.first_expired_timestamp
 
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
+
         self._test_transform_with_state_in_pandas_proc_timer(
             ProcTimeStatefulProcessor(), check_results
         )
@@ -552,12 +564,15 @@ class TransformWithStateInPandasTestsMixin:
                     Row(id="a", timestamp="20"),
                     Row(id="a-expired", timestamp="0"),
                 }
-            else:
+            elif batch_id == 2:
                 # verify that rows and expired timer produce the expected 
result
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="a", timestamp="15"),
                     Row(id="a-expired", timestamp="10000"),
                 }
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
 
         self._test_transform_with_state_in_pandas_event_time(
             EventTimeStatefulProcessor(), check_results
@@ -679,6 +694,9 @@ class TransformWithStateInPandasTestsMixin:
                     Row(id1="0", id2="1", value=str(123 + 46)),
                     Row(id1="1", id2="2", value=str(146 + 346)),
                 }
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
 
         self._test_transform_with_state_non_contiguous_grouping_cols(
             SimpleStatefulProcessorWithInitialState(), check_results
@@ -692,6 +710,9 @@ class TransformWithStateInPandasTestsMixin:
                     Row(id1="0", id2="1", value=str(789 + 123 + 46)),
                     Row(id1="1", id2="2", value=str(146 + 346)),
                 }
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
 
         # grouping key of initial state is also not starting from the 
beginning of attributes
         data = [(789, "0", "1"), (987, "3", "2")]
@@ -703,6 +724,88 @@ class TransformWithStateInPandasTestsMixin:
             SimpleStatefulProcessorWithInitialState(), check_results, 
initial_state
         )
 
+    def _test_transform_with_state_in_pandas_chaining_ops(
+        self, stateful_processor, check_results, timeMode="None", 
grouping_cols=["outputTimestamp"]
+    ):
+        import pyspark.sql.functions as f
+
+        input_path = tempfile.mkdtemp()
+        self._prepare_input_data(input_path + "/text-test3.txt", ["a", "b"], 
[10, 15])
+        time.sleep(2)
+        self._prepare_input_data(input_path + "/text-test4.txt", ["a", "c"], 
[11, 25])
+        time.sleep(2)
+        self._prepare_input_data(input_path + "/text-test1.txt", ["a"], [5])
+
+        df = self._build_test_df(input_path)
+        df = df.select(
+            "id", 
f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")
+        ).withWatermark("eventTime", "5 seconds")
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("outputTimestamp", TimestampType(), True),
+            ]
+        )
+
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Append",
+                timeMode=timeMode,
+                eventTimeColumnName="outputTimestamp",
+            )
+            .groupBy(grouping_cols)
+            .count()
+            .writeStream.queryName("chaining_ops_query")
+            .foreachBatch(check_results)
+            .outputMode("append")
+            .start()
+        )
+
+        self.assertEqual(q.name, "chaining_ops_query")
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+
+    def test_transform_with_state_in_pandas_chaining_ops(self):
+        def check_results(batch_df, batch_id):
+            import datetime
+
+            if batch_id == 0:
+                assert batch_df.isEmpty()
+            elif batch_id == 1:
+                # eviction watermark = 15 - 5 = 10 (max event time from batch 
0),
+                # late event watermark = 0 (eviction event time from batch 0)
+                assert set(
+                    batch_df.sort("outputTimestamp").select("outputTimestamp", 
"count").collect()
+                ) == {
+                    Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 
10), count=1),
+                }
+            elif batch_id == 2:
+                # eviction watermark = 25 - 5 = 20, late event watermark = 10;
+                # row with watermark=5<10 is dropped so it does not show up in 
the results;
+                # row with eventTime<=20 are finalized and emitted
+                assert set(
+                    batch_df.sort("outputTimestamp").select("outputTimestamp", 
"count").collect()
+                ) == {
+                    Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 
11), count=1),
+                    Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 
15), count=1),
+                }
+
+        self._test_transform_with_state_in_pandas_chaining_ops(
+            StatefulProcessorChainingOps(), check_results, "eventTime"
+        )
+        self._test_transform_with_state_in_pandas_chaining_ops(
+            StatefulProcessorChainingOps(), check_results, "eventTime", 
["outputTimestamp", "id"]
+        )
+
 
 class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
     # this dict is the same as input initial state dataframe
@@ -888,6 +991,21 @@ class SimpleStatefulProcessor(StatefulProcessor, 
unittest.TestCase):
         pass
 
 
+class StatefulProcessorChainingOps(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        pass
+
+    def handleInputRows(
+        self, key, rows, timer_values, expired_timer_info
+    ) -> Iterator[pd.DataFrame]:
+        for pdf in rows:
+            timestamp_list = pdf["eventTime"].tolist()
+        yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]})
+
+    def close(self) -> None:
+        pass
+
+
 # A stateful processor that inherit all behavior of SimpleStatefulProcessor 
except that it use
 # ttl state with a large timeout.
 class SimpleTTLStatefulProcessor(SimpleStatefulProcessor, unittest.TestCase):
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bed7bea61597..e05f3533ae3c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3653,7 +3653,6 @@ object CleanupAliases extends Rule[LogicalPlan] with 
AliasHelper {
 
 /**
  * Ignore event time watermark in batch query, which is only supported in 
Structured Streaming.
- * TODO: add this rule into analyzer rule list.
  */
 object EliminateEventTimeWatermark extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsWithPruning(
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 4f33c26d5c3c..5b7583c763c0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -103,6 +103,7 @@ object UnsupportedOperationChecker extends Logging {
     case d: Deduplicate if d.isStreaming && d.keys.exists(hasEventTimeCol) => 
true
     case d: DeduplicateWithinWatermark if d.isStreaming => true
     case t: TransformWithState if t.isStreaming => true
+    case t: TransformWithStateInPandas if t.isStreaming => true
     case _ => false
   }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 29216523fefc..0772c67ea27e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1031,6 +1031,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Can't prune the columns on LeafNode
     case p @ Project(_, _: LeafNode) => p
 
+    // Can't prune the columns on UpdateEventTimeWatermarkColumn
+    case p @ Project(_, _: UpdateEventTimeWatermarkColumn) => p
+
     case NestedColumnAliasing(rewrittenPlan) => rewrittenPlan
 
     // for all other logical plans that inherits the output from it's children
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 392c3edab989..6dcf01d3a9db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark, 
UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, 
ProductEncoder}
 import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -289,11 +289,11 @@ class KeyValueGroupedDataset[K, V] private[sql](
       transformWithState
     )
 
-    Dataset[U](sparkSession, EliminateEventTimeWatermark(
+    Dataset[U](sparkSession,
       UpdateEventTimeWatermarkColumn(
         UnresolvedAttribute(eventTimeColumnName),
         None,
-        transformWithStateDataset.logicalPlan)))
+        transformWithStateDataset.logicalPlan))
   }
 
   /** @inheritdoc */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 0974df55a6d8..6f0db42ec1f5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkRuntimeException
 import org.apache.spark.annotation.Stable
 import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, 
UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -475,7 +475,8 @@ class RelationalGroupedDataset protected[sql](
       outputStructType: StructType,
       outputModeStr: String,
       timeModeStr: String,
-      initialState: RelationalGroupedDataset): DataFrame = {
+      initialState: RelationalGroupedDataset,
+      eventTimeColumnName: String): DataFrame = {
     def exprToAttr(expr: Seq[Expression]): Seq[Attribute] = {
       expr.map {
         case ne: NamedExpression => ne
@@ -529,7 +530,30 @@ class RelationalGroupedDataset protected[sql](
         initialStateSchema = initialState.df.schema
       )
     }
-    Dataset.ofRows(df.sparkSession, plan)
+    if (eventTimeColumnName.isEmpty) {
+      Dataset.ofRows(df.sparkSession, plan)
+    } else {
+      updateEventTimeColumnAfterTransformWithState(plan, eventTimeColumnName)
+    }
+  }
+
+  /**
+   * Wraps the transformWithStateInPandas plan in an UpdateEventTimeWatermarkColumn node
+   * that marks `eventTimeColumnName` as the event time column of the operator's output,
+   * so downstream operators can evaluate watermark expressions against that column.
+   *
+   * @param transformWithStateInPandas logical plan of the transformWithStateInPandas operator
+   * @param eventTimeColumnName name of an output column, passed as an UnresolvedAttribute
+   * @return a DataFrame over the wrapped plan
+   */
+  private def updateEventTimeColumnAfterTransformWithState(
+      transformWithStateInPandas: LogicalPlan,
+      eventTimeColumnName: String): DataFrame = {
+    // NOTE(review): the plan is first turned into a Dataset and its `logicalPlan` is used,
+    // presumably to obtain the analyzed plan before wrapping it -- confirm.
+    val transformWithStateDataset = Dataset.ofRows(
+      df.sparkSession,
+      transformWithStateInPandas
+    )
+
+    // The second argument (presumably the watermark delay) is left as None.
+    Dataset.ofRows(df.sparkSession,
+      UpdateEventTimeWatermarkColumn(
+        UnresolvedAttribute(eventTimeColumnName),
+        None,
+        transformWithStateDataset.logicalPlan))
   }
 
   override def toString: String = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 22082aca81a2..c621c151c0bd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, AnalysisException, Strategy}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide, JoinSelectionHelper, NormalizeFloatingNumbers}
@@ -966,6 +966,14 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       case _: FlatMapGroupsInPandasWithState =>
         // TODO(SPARK-40443): support applyInPandasWithState in batch query
         throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3176")
+      case t: TransformWithStateInPandas =>
+        // TODO(SPARK-50428): support TransformWithStateInPandas in batch query
+        throw new ExtendedAnalysisException(
+          new AnalysisException(
+            "_LEGACY_ERROR_TEMP_3102",
+            Map(
+              "msg" -> "TransformWithStateInPandas is not supported with batch 
DataFrames/Datasets")
+          ), plan = t)
       case logical.CoGroup(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, lOrder, rOrder, 
oAttr, left, right) =>
         execution.CoGroupExec(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
index 7dd4d4647eeb..617c20c3a782 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, 
Expression, PythonUDF, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.ProcessingTime
 import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, 
SparkPlan}
@@ -72,6 +73,8 @@ case class TransformWithStateInPandasExec(
     initialStateSchema: StructType)
   extends BinaryExecNode with StateStoreWriter with WatermarkSupport {
 
+  override def shortName: String = "transformWithStateInPandasExec"
+
   private val pythonUDF = functionExpr.asInstanceOf[PythonUDF]
   private val pythonFunction = pythonUDF.func
   private val chainedFunc =
@@ -126,6 +129,37 @@ case class TransformWithStateInPandasExec(
     List.empty
   }
 
+  /**
+   * Returns whether another batch should be run for this operator given
+   * `newInputWatermark`.
+   *
+   * In ProcessingTime mode this always returns true. Otherwise, in Append or Update
+   * output mode, it returns true only when an eviction watermark is set and the new input
+   * watermark is strictly greater than it; in all other cases it returns false. This
+   * appears to mirror TransformWithStateExec.shouldRunAnotherBatch.
+   */
+  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
+    if (timeMode == ProcessingTime) {
+      // TODO SPARK-50180: check if we can return true only if actual timers are registered,
+      //  or there is expired state
+      true
+    } else if (outputMode == OutputMode.Append || outputMode == 
OutputMode.Update) {
+      eventTimeWatermarkForEviction.isDefined &&
+        newInputWatermark > eventTimeWatermarkForEviction.get
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Controls watermark propagation to downstream nodes. If timeMode is
+   * ProcessingTime, the output rows cannot be interpreted in event time, hence
+   * this node does not propagate a watermark in this timeMode (returns None).
+   *
+   * For every other timeMode (e.g. EventTime), the output watermark is the same as the
+   * input watermark, because transformWithStateInPandas does not allow users to set the
+   * event time column to be earlier than the watermark.
+   */
+  override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
+    timeMode match {
+      case ProcessingTime =>
+        None
+      case _ =>
+        Some(inputWatermarkMs)
+    }
+  }
+
   override def customStatefulOperatorMetrics: 
Seq[StatefulOperatorCustomMetric] = {
     Seq(
       // metrics around state variables
@@ -214,8 +248,15 @@ case class TransformWithStateInPandasExec(
     val updatesStartTimeNs = currentTimeNs
 
     val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, 
groupingAttributes)
-    val data =
-      groupAndProject(dataIterator, groupingAttributes, child.output, 
dedupAttributes)
+    // If a late-event watermark predicate is defined, drop input rows older than the watermark
+    val filteredIter = watermarkPredicateForDataForLateEvents match {
+      case Some(predicate) =>
+        applyRemovingRowsOlderThanWatermark(dataIterator, predicate)
+      case _ =>
+        dataIterator
+    }
+
+    val data = groupAndProject(filteredIter, groupingAttributes, child.output, 
dedupAttributes)
 
     val processorHandle = new StatefulProcessorHandleImpl(store, 
getStateInfo.queryRunId,
       groupingKeyExprEncoder, timeMode, isStreaming = true, batchTimestampMs, 
metrics)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 2a7e9818aedd..719c4da14d72 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -439,6 +439,23 @@ class IncrementalExecution(
               eventTimeWatermarkForEviction = iwEviction)
           ))
 
+      // UpdateEventTimeColumnExec is used to tag the eventTime column, and to validate that
+      // emitted rows adhere to the watermark in the output of transformWithStateInPandas.
+      // Hence, this node shares the same watermark value as TransformWithStateInPandasExec.
+      // This is the same as above in TransformWithStateExec.
+      // The only difference is that TransformWithStateInPandasExec is analyzed slightly
+      // differently, with no SerializeFromObjectExec wrapper.
+      case UpdateEventTimeColumnExec(eventTime, delay, None, t: 
TransformWithStateInPandasExec)
+        if t.stateInfo.isDefined =>
+        val stateInfo = t.stateInfo.get
+        val iwLateEvents = inputWatermarkForLateEvents(stateInfo)
+        val iwEviction = inputWatermarkForEviction(stateInfo)
+
+        UpdateEventTimeColumnExec(eventTime, delay, iwLateEvents,
+          t.copy(
+            eventTimeWatermarkForLateEvents = iwLateEvents,
+            eventTimeWatermarkForEviction = iwEviction)
+        )
 
       case t: TransformWithStateExec if t.stateInfo.isDefined =>
         t.copy(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index f4705b89d5a8..9c31ff0a7443 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -85,8 +85,8 @@ case class TransformWithStateExec(
 
   override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
     if (timeMode == ProcessingTime) {
-      // TODO: check if we can return true only if actual timers are 
registered, or there is
-      // expired state
+      // TODO SPARK-50180: check if we can return true only if actual timers 
are registered,
+      //  or there is expired state
       true
     } else if (outputMode == OutputMode.Append || outputMode == 
OutputMode.Update) {
       eventTimeWatermarkForEviction.isDefined &&


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to