This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e6252d6c6143 [SPARK-50194][SS][PYTHON] Integration of New Timer API 
and Initial State API with Timer
e6252d6c6143 is described below

commit e6252d6c6143f134201a17bb25978a136186cbfa
Author: jingz-db <[email protected]>
AuthorDate: Thu Nov 28 15:34:37 2024 +0900

    [SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State 
API with Timer
    
    ### What changes were proposed in this pull request?
    
    On the Scala side, the timer API was modified to use a separate 
`handleExpiredTimer` function inside `StatefulProcessor`; this PR makes the 
same change to the Python timer API so that it matches the Scala side. It also 
adds a `timer_values` parameter to the `handleInitialState` function to support 
registering timers in the first batch for initial state rows.
    
    ### Why are the changes needed?
    
    This change aligns the Python API with the Scala-side API: 
https://github.com/apache/spark/pull/48553
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    We add a new user-defined function to explicitly handle expired timers:
    ```
    def handleExpiredTimer(
            self, key: Any, timer_values: TimerValues, expired_timer_info: 
ExpiredTimerInfo
    ) -> Iterator["PandasDataFrameLike"]
    ```
    We also add a new timer parameter to enable users to register timers for 
keys that exist in the initial state:
    ```
    def handleInitialState(
            self,
            key: Any,
            initialState: "PandasDataFrameLike",
            timer_values: TimerValues) -> None
    ```
    
    ### How was this patch tested?
    
    Add a new test in `test_pandas_transform_with_state`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #48838 from jingz-db/python-new-timer.
    
    Lead-authored-by: jingz-db <[email protected]>
    Co-authored-by: Jing Zhan <[email protected]>
    Co-authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 python/pyspark/sql/pandas/group_ops.py             | 107 ++++++----
 python/pyspark/sql/pandas/serializers.py           |  13 +-
 python/pyspark/sql/streaming/stateful_processor.py |  49 +++--
 .../sql/streaming/stateful_processor_api_client.py |  94 +++++----
 .../sql/streaming/stateful_processor_util.py       |  27 +++
 .../pandas/test_pandas_transform_with_state.py     | 227 ++++++++++++---------
 python/pyspark/worker.py                           |  68 +++---
 .../TransformWithStateInPandasStateServer.scala    |   2 +
 8 files changed, 363 insertions(+), 224 deletions(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index d8f22e434374..688ad4b05732 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -35,6 +35,7 @@ from pyspark.sql.streaming.stateful_processor import (
     TimerValues,
 )
 from pyspark.sql.streaming.stateful_processor import StatefulProcessor, 
StatefulProcessorHandle
+from pyspark.sql.streaming.stateful_processor_util import 
TransformWithStateInPandasFuncMode
 from pyspark.sql.types import StructType, _parse_datatype_string
 
 if TYPE_CHECKING:
@@ -503,58 +504,59 @@ class PandasGroupedOpsMixin:
         if isinstance(outputStructType, str):
             outputStructType = cast(StructType, 
_parse_datatype_string(outputStructType))
 
-        def handle_data_with_timers(
+        def handle_data_rows(
             statefulProcessorApiClient: StatefulProcessorApiClient,
             key: Any,
-            inputRows: Iterator["PandasDataFrameLike"],
+            inputRows: Optional[Iterator["PandasDataFrameLike"]] = None,
         ) -> Iterator["PandasDataFrameLike"]:
             statefulProcessorApiClient.set_implicit_key(key)
-            if timeMode != "none":
-                batch_timestamp = 
statefulProcessorApiClient.get_batch_timestamp()
-                watermark_timestamp = 
statefulProcessorApiClient.get_watermark_timestamp()
+
+            batch_timestamp, watermark_timestamp = 
statefulProcessorApiClient.get_timestamps(
+                timeMode
+            )
+
+            # process with data rows
+            if inputRows is not None:
+                data_iter = statefulProcessor.handleInputRows(
+                    key, inputRows, TimerValues(batch_timestamp, 
watermark_timestamp)
+                )
+                return data_iter
             else:
-                batch_timestamp = -1
-                watermark_timestamp = -1
-            # process with invalid expiry timer info and emit data rows
-            data_iter = statefulProcessor.handleInputRows(
-                key,
-                inputRows,
-                TimerValues(batch_timestamp, watermark_timestamp),
-                ExpiredTimerInfo(False),
+                return iter([])
+
+        def handle_expired_timers(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+        ) -> Iterator["PandasDataFrameLike"]:
+            batch_timestamp, watermark_timestamp = 
statefulProcessorApiClient.get_timestamps(
+                timeMode
             )
-            
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED)
 
-            if timeMode == "processingtime":
+            if timeMode.lower() == "processingtime":
                 expiry_list_iter = 
statefulProcessorApiClient.get_expiry_timers_iterator(
                     batch_timestamp
                 )
-            elif timeMode == "eventtime":
+            elif timeMode.lower() == "eventtime":
                 expiry_list_iter = 
statefulProcessorApiClient.get_expiry_timers_iterator(
                     watermark_timestamp
                 )
             else:
                 expiry_list_iter = iter([[]])
 
-            result_iter_list = [data_iter]
-            # process with valid expiry time info and with empty input rows,
-            # only timer related rows will be emitted
+            # process with expiry timers, only timer related rows will be 
emitted
             for expiry_list in expiry_list_iter:
                 for key_obj, expiry_timestamp in expiry_list:
-                    result_iter_list.append(
-                        statefulProcessor.handleInputRows(
-                            key_obj,
-                            iter([]),
-                            TimerValues(batch_timestamp, watermark_timestamp),
-                            ExpiredTimerInfo(True, expiry_timestamp),
-                        )
-                    )
-            # TODO(SPARK-49603) set the handle state in the lazily initialized 
iterator
-
-            result = itertools.chain(*result_iter_list)
-            return result
+                    statefulProcessorApiClient.set_implicit_key(key_obj)
+                    for pd in statefulProcessor.handleExpiredTimer(
+                        key=key_obj,
+                        timer_values=TimerValues(batch_timestamp, 
watermark_timestamp),
+                        expired_timer_info=ExpiredTimerInfo(expiry_timestamp),
+                    ):
+                        yield pd
+                    statefulProcessorApiClient.delete_timer(expiry_timestamp)
 
         def transformWithStateUDF(
             statefulProcessorApiClient: StatefulProcessorApiClient,
+            mode: TransformWithStateInPandasFuncMode,
             key: Any,
             inputRows: Iterator["PandasDataFrameLike"],
         ) -> Iterator["PandasDataFrameLike"]:
@@ -566,19 +568,28 @@ class PandasGroupedOpsMixin:
                     StatefulProcessorHandleState.INITIALIZED
                 )
 
-            # Key is None when we have processed all the input data from the 
worker and ready to
-            # proceed with the cleanup steps.
-            if key is None:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.DATA_PROCESSED
+                )
+                result = handle_expired_timers(statefulProcessorApiClient)
+                return result
+            elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.TIMER_PROCESSED
+                )
                 statefulProcessorApiClient.remove_implicit_key()
                 statefulProcessor.close()
                 
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
                 return iter([])
-
-            result = handle_data_with_timers(statefulProcessorApiClient, key, 
inputRows)
-            return result
+            else:
+                # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+                result = handle_data_rows(statefulProcessorApiClient, key, 
inputRows)
+                return result
 
         def transformWithStateWithInitStateUDF(
             statefulProcessorApiClient: StatefulProcessorApiClient,
+            mode: TransformWithStateInPandasFuncMode,
             key: Any,
             inputRows: Iterator["PandasDataFrameLike"],
             initialStates: Optional[Iterator["PandasDataFrameLike"]] = None,
@@ -603,20 +614,30 @@ class PandasGroupedOpsMixin:
                     StatefulProcessorHandleState.INITIALIZED
                 )
 
-            # Key is None when we have processed all the input data from the 
worker and ready to
-            # proceed with the cleanup steps.
-            if key is None:
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.DATA_PROCESSED
+                )
+                result = handle_expired_timers(statefulProcessorApiClient)
+                return result
+            elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
                 statefulProcessorApiClient.remove_implicit_key()
                 statefulProcessor.close()
                 
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
                 return iter([])
+            else:
+                # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+                batch_timestamp, watermark_timestamp = 
statefulProcessorApiClient.get_timestamps(
+                    timeMode
+                )
 
             # only process initial state if first batch and initial state is 
not None
             if initialStates is not None:
                 for cur_initial_state in initialStates:
                     statefulProcessorApiClient.set_implicit_key(key)
-                    # TODO(SPARK-50194) integration with new timer API with 
initial state
-                    statefulProcessor.handleInitialState(key, 
cur_initial_state)
+                    statefulProcessor.handleInitialState(
+                        key, cur_initial_state, TimerValues(batch_timestamp, 
watermark_timestamp)
+                    )
 
             # if we don't have input rows for the given key but only have 
initial state
             # for the grouping key, the inputRows iterator could be empty
@@ -629,7 +650,7 @@ class PandasGroupedOpsMixin:
                 inputRows = itertools.chain([first], inputRows)
 
             if not input_rows_empty:
-                result = handle_data_with_timers(statefulProcessorApiClient, 
key, inputRows)
+                result = handle_data_rows(statefulProcessorApiClient, key, 
inputRows)
             else:
                 result = iter([])
 
diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 5bf07b87400f..536bf7307065 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -36,6 +36,7 @@ from pyspark.sql.pandas.types import (
     _create_converter_from_pandas,
     _create_converter_to_pandas,
 )
+from pyspark.sql.streaming.stateful_processor_util import 
TransformWithStateInPandasFuncMode
 from pyspark.sql.types import (
     DataType,
     StringType,
@@ -1197,7 +1198,11 @@ class 
TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (k, g)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
+
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
 
     def dump_stream(self, iterator, stream):
         """
@@ -1281,4 +1286,8 @@ class 
TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (k, g)
+            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+
+        yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
+
+        yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
diff --git a/python/pyspark/sql/streaming/stateful_processor.py 
b/python/pyspark/sql/streaming/stateful_processor.py
index 20078c215bac..9caa9304d6a8 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -105,21 +105,13 @@ class TimerValues:
 
 class ExpiredTimerInfo:
     """
-    Class used for arbitrary stateful operations with transformWithState to 
access expired timer
-    info. When is_valid is false, the expiry timestamp is invalid.
+    Class used to provide access to expired timer's expiry time.
     .. versionadded:: 4.0.0
     """
 
-    def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None:
-        self._is_valid = is_valid
+    def __init__(self, expiry_time_in_ms: int = -1) -> None:
         self._expiry_time_in_ms = expiry_time_in_ms
 
-    def is_valid(self) -> bool:
-        """
-        Whether the expiry info is valid.
-        """
-        return self._is_valid
-
     def get_expiry_time_in_ms(self) -> int:
         """
         Get the timestamp for expired timer, return timestamp in millisecond.
@@ -398,7 +390,6 @@ class StatefulProcessor(ABC):
         key: Any,
         rows: Iterator["PandasDataFrameLike"],
         timer_values: TimerValues,
-        expired_timer_info: ExpiredTimerInfo,
     ) -> Iterator["PandasDataFrameLike"]:
         """
         Function that will allow users to interact with input data rows along 
with the grouping key.
@@ -420,11 +411,29 @@ class StatefulProcessor(ABC):
         timer_values: TimerValues
                       Timer value for the current batch that process the input 
rows.
                       Users can get the processing or event time timestamp 
from TimerValues.
-        expired_timer_info: ExpiredTimerInfo
-                            Timestamp of expired timers on the grouping key.
         """
         ...
 
+    def handleExpiredTimer(
+        self, key: Any, timer_values: TimerValues, expired_timer_info: 
ExpiredTimerInfo
+    ) -> Iterator["PandasDataFrameLike"]:
+        """
+        Optional to implement. Will return an empty iterator if not 
defined.
+        Function that will be invoked when a timer is fired for a given key. 
Users can choose to
+        evict state, register new timers and optionally provide output rows.
+
+        Parameters
+        ----------
+        key : Any
+            grouping key.
+        timer_values: TimerValues
+                      Timer value for the current batch that process the input 
rows.
+                      Users can get the processing or event time timestamp 
from TimerValues.
+        expired_timer_info: ExpiredTimerInfo
+                            Instance of ExpiredTimerInfo that provides access 
to expired timer.
+        """
+        return iter([])
+
     @abstractmethod
     def close(self) -> None:
         """
@@ -433,9 +442,21 @@ class StatefulProcessor(ABC):
         """
         ...
 
-    def handleInitialState(self, key: Any, initialState: 
"PandasDataFrameLike") -> None:
+    def handleInitialState(
+        self, key: Any, initialState: "PandasDataFrameLike", timer_values: 
TimerValues
+    ) -> None:
         """
         Optional to implement. Will act as no-op if not defined or no initial 
state input.
          Function that will be invoked only in the first batch for users to 
process initial states.
+
+        Parameters
+        ----------
+        key : Any
+            grouping key.
+        initialState: :class:`pandas.DataFrame`
+                      One dataframe in the initial state associated with the 
key.
+        timer_values: TimerValues
+                      Timer value for the current batch that process the input 
rows.
+                      Users can get the processing or event time timestamp 
from TimerValues.
         """
         pass
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py 
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 353f75e26796..53704188081c 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -62,6 +62,10 @@ class StatefulProcessorApiClient:
         # Dictionaries to store the mapping between iterator id and a tuple of 
pandas DataFrame
         # and the index of the last row that was read.
         self.list_timer_iterator_cursors: Dict[str, 
Tuple["PandasDataFrameLike", int]] = {}
+        # statefulProcessorApiClient is initialized per batch per partition,
+        # so we will have new timestamps for a new batch
+        self._batch_timestamp = -1
+        self._watermark_timestamp = -1
 
     def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -266,47 +270,15 @@ class StatefulProcessorApiClient:
                 # TODO(SPARK-49233): Classify user facing errors.
                 raise PySparkRuntimeError(f"Error getting expiry timers: " 
f"{response_message[1]}")
 
-    def get_batch_timestamp(self) -> int:
-        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
-
-        get_processing_time_call = stateMessage.GetProcessingTime()
-        timer_value_call = stateMessage.TimerValueRequest(
-            getProcessingTimer=get_processing_time_call
-        )
-        timer_request = 
stateMessage.TimerRequest(timerValueRequest=timer_value_call)
-        message = stateMessage.StateRequest(timerRequest=timer_request)
-
-        self._send_proto_message(message.SerializeToString())
-        response_message = self._receive_proto_message_with_long_value()
-        status = response_message[0]
-        if status != 0:
-            # TODO(SPARK-49233): Classify user facing errors.
-            raise PySparkRuntimeError(
-                f"Error getting processing timestamp: " 
f"{response_message[1]}"
-            )
+    def get_timestamps(self, time_mode: str) -> Tuple[int, int]:
+        if time_mode.lower() == "none":
+            return -1, -1
         else:
-            timestamp = response_message[2]
-            return timestamp
-
-    def get_watermark_timestamp(self) -> int:
-        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
-
-        get_watermark_call = stateMessage.GetWatermark()
-        timer_value_call = 
stateMessage.TimerValueRequest(getWatermark=get_watermark_call)
-        timer_request = 
stateMessage.TimerRequest(timerValueRequest=timer_value_call)
-        message = stateMessage.StateRequest(timerRequest=timer_request)
-
-        self._send_proto_message(message.SerializeToString())
-        response_message = self._receive_proto_message_with_long_value()
-        status = response_message[0]
-        if status != 0:
-            # TODO(SPARK-49233): Classify user facing errors.
-            raise PySparkRuntimeError(
-                f"Error getting eventtime timestamp: " f"{response_message[1]}"
-            )
-        else:
-            timestamp = response_message[2]
-            return timestamp
+            if self._batch_timestamp == -1:
+                self._batch_timestamp = self._get_batch_timestamp()
+            if self._watermark_timestamp == -1:
+                self._watermark_timestamp = self._get_watermark_timestamp()
+        return self._batch_timestamp, self._watermark_timestamp
 
     def get_map_state(
         self,
@@ -353,6 +325,48 @@ class StatefulProcessorApiClient:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error deleting state: " 
f"{response_message[1]}")
 
+    def _get_batch_timestamp(self) -> int:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        get_processing_time_call = stateMessage.GetProcessingTime()
+        timer_value_call = stateMessage.TimerValueRequest(
+            getProcessingTimer=get_processing_time_call
+        )
+        timer_request = 
stateMessage.TimerRequest(timerValueRequest=timer_value_call)
+        message = stateMessage.StateRequest(timerRequest=timer_request)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message_with_long_value()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error getting processing timestamp: " 
f"{response_message[1]}"
+            )
+        else:
+            timestamp = response_message[2]
+            return timestamp
+
+    def _get_watermark_timestamp(self) -> int:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        get_watermark_call = stateMessage.GetWatermark()
+        timer_value_call = 
stateMessage.TimerValueRequest(getWatermark=get_watermark_call)
+        timer_request = 
stateMessage.TimerRequest(timerValueRequest=timer_value_call)
+        message = stateMessage.StateRequest(timerRequest=timer_request)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message_with_long_value()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error getting eventtime timestamp: " f"{response_message[1]}"
+            )
+        else:
+            timestamp = response_message[2]
+            return timestamp
+
     def _send_proto_message(self, message: bytes) -> None:
         # Writing zero here to indicate message version. This allows us to 
evolve the message
         # format or even changing the message protocol in the future.
diff --git a/python/pyspark/sql/streaming/stateful_processor_util.py 
b/python/pyspark/sql/streaming/stateful_processor_util.py
new file mode 100644
index 000000000000..6130a9581bc2
--- /dev/null
+++ b/python/pyspark/sql/streaming/stateful_processor_util.py
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from enum import Enum
+
+# This file places the utilities for transformWithStateInPandas; we have a 
separate file to avoid
+# putting internal classes to the stateful_processor.py file which contains 
public APIs.
+
+
+class TransformWithStateInPandasFuncMode(Enum):
+    PROCESS_DATA = 1
+    PROCESS_TIMER = 2
+    COMPLETE = 3
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index f385d7cd1abc..60f2c9348db3 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -55,6 +55,7 @@ class TransformWithStateInPandasTestsMixin:
             
"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
         )
         
cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch",
 "2")
+        cfg.set("spark.sql.session.timeZone", "UTC")
         return cfg
 
     def _prepare_input_data(self, input_path, col1, col2):
@@ -558,14 +559,25 @@ class TransformWithStateInPandasTestsMixin:
     def test_transform_with_state_in_pandas_event_time(self):
         def check_results(batch_df, batch_id):
             if batch_id == 0:
-                assert set(batch_df.sort("id").collect()) == {Row(id="a", 
timestamp="20")}
-            elif batch_id == 1:
+                # watermark for late event = 0
+                # watermark for eviction = 0
+                # timer is registered with expiration time = 0, hence expired 
at the same batch
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="a", timestamp="20"),
                     Row(id="a-expired", timestamp="0"),
                 }
+            elif batch_id == 1:
+                # watermark for late event = 0
+                # watermark for eviction = 10 (20 - 10)
+                # timer is registered with expiration time = 10, hence expired 
at the same batch
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="a", timestamp="4"),
+                    Row(id="a-expired", timestamp="10000"),
+                }
             elif batch_id == 2:
-                # verify that rows and expired timer produce the expected 
result
+                # watermark for late event = 10
+                # watermark for eviction = 10 (unchanged as 4 < 10)
+                # timer is registered with expiration time = 10, hence expired 
at the same batch
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="a", timestamp="15"),
                     Row(id="a-expired", timestamp="10000"),
@@ -578,7 +590,9 @@ class TransformWithStateInPandasTestsMixin:
             EventTimeStatefulProcessor(), check_results
         )
 
-    def _test_transform_with_state_init_state_in_pandas(self, 
stateful_processor, check_results):
+    def _test_transform_with_state_init_state_in_pandas(
+        self, stateful_processor, check_results, time_mode="None"
+    ):
         input_path = tempfile.mkdtemp()
         self._prepare_test_resource1(input_path)
         time.sleep(2)
@@ -606,7 +620,7 @@ class TransformWithStateInPandasTestsMixin:
                 statefulProcessor=stateful_processor,
                 outputStructType=output_schema,
                 outputMode="Update",
-                timeMode="None",
+                timeMode=time_mode,
                 initialState=initial_state,
             )
             .writeStream.queryName("this_query")
@@ -806,6 +820,45 @@ class TransformWithStateInPandasTestsMixin:
             StatefulProcessorChainingOps(), check_results, "eventTime", 
["outputTimestamp", "id"]
         )
 
+    def test_transform_with_state_init_state_with_timers(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                # timers are registered and handled in the first batch for
+                # rows in initial state; For key=0 and key=3 which contains
+                # expired timers, both should be handled by handleExpiredTimers
+                # regardless of whether key exists in the data rows or not
+                expired_df = 
batch_df.filter(batch_df["id"].contains("expired"))
+                data_df = batch_df.filter(~batch_df["id"].contains("expired"))
+                assert set(expired_df.sort("id").select("id").collect()) == {
+                    Row(id="0-expired"),
+                    Row(id="3-expired"),
+                }
+                assert set(data_df.sort("id").collect()) == {
+                    Row(id="0", value=str(789 + 123 + 46)),
+                    Row(id="1", value=str(146 + 346)),
+                }
+            elif batch_id == 1:
+                # handleInitialState is only processed in the first batch,
+                # no more timer is registered so no more expired timers
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="0", value=str(789 + 123 + 46 + 67)),
+                    Row(id="3", value=str(987 + 12)),
+                }
+            else:
+                for q in self.spark.streams.active:
+                    q.stop()
+
+        self._test_transform_with_state_init_state_in_pandas(
+            StatefulProcessorWithInitialStateTimers(), check_results, 
"processingTime"
+        )
+
+    # run the same test suites again but with single shuffle partition
+    def test_transform_with_state_with_timers_single_partition(self):
+        with self.sql_conf({"spark.sql.shuffle.partitions": "1"}):
+            self.test_transform_with_state_init_state_with_timers()
+            self.test_transform_with_state_in_pandas_event_time()
+            self.test_transform_with_state_in_pandas_proc_timer()
+
 
 class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
     # this dict is the same as input initial state dataframe
@@ -814,10 +867,9 @@ class 
SimpleStatefulProcessorWithInitialState(StatefulProcessor):
     def init(self, handle: StatefulProcessorHandle) -> None:
         state_schema = StructType([StructField("value", IntegerType(), True)])
         self.value_state = handle.getValueState("value_state", state_schema)
+        self.handle = handle
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
         exists = self.value_state.exists()
         if exists:
             value_row = self.value_state.get()
@@ -840,7 +892,7 @@ class 
SimpleStatefulProcessorWithInitialState(StatefulProcessor):
         else:
             yield pd.DataFrame({"id": key, "value": str(accumulated_value)})
 
-    def handleInitialState(self, key, initialState) -> None:
+    def handleInitialState(self, key, initialState, timer_values) -> None:
         init_val = initialState.at[0, "initVal"]
         self.value_state.update((init_val,))
         if len(key) == 1:
@@ -850,6 +902,19 @@ class 
SimpleStatefulProcessorWithInitialState(StatefulProcessor):
         pass
 
 
+class 
StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState):
+    def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> 
Iterator[pd.DataFrame]:
+        self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+        str_key = f"{str(key[0])}-expired"
+        yield pd.DataFrame(
+            {"id": (str_key,), "value": 
str(expired_timer_info.get_expiry_time_in_ms())}
+        )
+
+    def handleInitialState(self, key, initialState, timer_values) -> None:
+        super().handleInitialState(key, initialState, timer_values)
+        
self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1)
+
+
 # A stateful processor that output the max event time it has seen. Register 
timer for
 # current watermark. Clear max state if timer expires.
 class EventTimeStatefulProcessor(StatefulProcessor):
@@ -858,33 +923,30 @@ class EventTimeStatefulProcessor(StatefulProcessor):
         self.handle = handle
         self.max_state = handle.getValueState("max_state", state_schema)
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
-        if expired_timer_info.is_valid():
-            self.max_state.clear()
-            self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
-            str_key = f"{str(key[0])}-expired"
-            yield pd.DataFrame(
-                {"id": (str_key,), "timestamp": 
str(expired_timer_info.get_expiry_time_in_ms())}
-            )
+    def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> 
Iterator[pd.DataFrame]:
+        self.max_state.clear()
+        self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+        str_key = f"{str(key[0])}-expired"
+        yield pd.DataFrame(
+            {"id": (str_key,), "timestamp": 
str(expired_timer_info.get_expiry_time_in_ms())}
+        )
 
-        else:
-            timestamp_list = []
-            for pdf in rows:
-                # int64 will represent timestamp in nanosecond, restore to 
second
-                timestamp_list.extend((pdf["eventTime"].astype("int64") // 
10**9).tolist())
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
+        timestamp_list = []
+        for pdf in rows:
+            # int64 will represent timestamp in nanosecond, restore to second
+            timestamp_list.extend((pdf["eventTime"].astype("int64") // 
10**9).tolist())
 
-            if self.max_state.exists():
-                cur_max = int(self.max_state.get()[0])
-            else:
-                cur_max = 0
-            max_event_time = str(max(cur_max, max(timestamp_list)))
+        if self.max_state.exists():
+            cur_max = int(self.max_state.get()[0])
+        else:
+            cur_max = 0
+        max_event_time = str(max(cur_max, max(timestamp_list)))
 
-            self.max_state.update((max_event_time,))
-            
self.handle.registerTimer(timer_values.get_current_watermark_in_ms())
+        self.max_state.update((max_event_time,))
+        self.handle.registerTimer(timer_values.get_current_watermark_in_ms())
 
-            yield pd.DataFrame({"id": key, "timestamp": max_event_time})
+        yield pd.DataFrame({"id": key, "timestamp": max_event_time})
 
     def close(self) -> None:
         pass
@@ -898,54 +960,49 @@ class ProcTimeStatefulProcessor(StatefulProcessor):
         self.handle = handle
         self.count_state = handle.getValueState("count_state", state_schema)
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
-        if expired_timer_info.is_valid():
-            # reset count state each time the timer is expired
-            timer_list_1 = [e for e in self.handle.listTimers()]
-            timer_list_2 = []
-            idx = 0
-            for e in self.handle.listTimers():
-                timer_list_2.append(e)
-                # check multiple iterator on the same grouping key works
-                assert timer_list_2[idx] == timer_list_1[idx]
-                idx += 1
-
-            if len(timer_list_1) > 0:
-                # before deleting the expiring timers, there are 2 timers -
-                # one timer we just registered, and one that is going to be 
deleted
-                assert len(timer_list_1) == 2
-            self.count_state.clear()
-            self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
-            yield pd.DataFrame(
-                {
-                    "id": key,
-                    "countAsString": str("-1"),
-                    "timeValues": 
str(expired_timer_info.get_expiry_time_in_ms()),
-                }
-            )
+    def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> 
Iterator[pd.DataFrame]:
+        # reset count state each time the timer is expired
+        timer_list_1 = [e for e in self.handle.listTimers()]
+        timer_list_2 = []
+        idx = 0
+        for e in self.handle.listTimers():
+            timer_list_2.append(e)
+            # check multiple iterator on the same grouping key works
+            assert timer_list_2[idx] == timer_list_1[idx]
+            idx += 1
+
+        if len(timer_list_1) > 0:
+            assert len(timer_list_1) == 2
+        self.count_state.clear()
+        self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+        yield pd.DataFrame(
+            {
+                "id": key,
+                "countAsString": str("-1"),
+                "timeValues": str(expired_timer_info.get_expiry_time_in_ms()),
+            }
+        )
 
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
+        if not self.count_state.exists():
+            count = 0
         else:
-            if not self.count_state.exists():
-                count = 0
-            else:
-                count = int(self.count_state.get()[0])
+            count = int(self.count_state.get()[0])
 
-            if key == ("0",):
-                
self.handle.registerTimer(timer_values.get_current_processing_time_in_ms())
+        if key == ("0",):
+            
self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1)
 
-            rows_count = 0
-            for pdf in rows:
-                pdf_count = len(pdf)
-                rows_count += pdf_count
+        rows_count = 0
+        for pdf in rows:
+            pdf_count = len(pdf)
+            rows_count += pdf_count
 
-            count = count + rows_count
+        count = count + rows_count
 
-            self.count_state.update((str(count),))
-            timestamp = str(timer_values.get_current_processing_time_in_ms())
+        self.count_state.update((str(count),))
+        timestamp = str(timer_values.get_current_processing_time_in_ms())
 
-            yield pd.DataFrame({"id": key, "countAsString": str(count), 
"timeValues": timestamp})
+        yield pd.DataFrame({"id": key, "countAsString": str(count), 
"timeValues": timestamp})
 
     def close(self) -> None:
         pass
@@ -961,9 +1018,7 @@ class SimpleStatefulProcessor(StatefulProcessor, 
unittest.TestCase):
         self.temp_state = handle.getValueState("tempState", state_schema)
         handle.deleteIfExists("tempState")
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
         with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value 
state exists"):
             self.temp_state.exists()
         new_violations = 0
@@ -995,9 +1050,7 @@ class StatefulProcessorChainingOps(StatefulProcessor):
     def init(self, handle: StatefulProcessorHandle) -> None:
         pass
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
         for pdf in rows:
             timestamp_list = pdf["eventTime"].tolist()
         yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]})
@@ -1027,9 +1080,7 @@ class TTLStatefulProcessor(StatefulProcessor):
             "ttl-map-state", user_key_schema, state_schema, 10000
         )
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
         count = 0
         ttl_count = 0
         ttl_list_state_count = 0
@@ -1079,9 +1130,7 @@ class InvalidSimpleStatefulProcessor(StatefulProcessor):
         state_schema = StructType([StructField("value", IntegerType(), True)])
         self.num_violations_state = handle.getValueState("numViolations", 
state_schema)
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
         count = 0
         exists = self.num_violations_state.exists()
         assert not exists
@@ -1105,9 +1154,7 @@ class ListStateProcessor(StatefulProcessor):
         self.list_state1 = handle.getListState("listState1", state_schema)
         self.list_state2 = handle.getListState("listState2", state_schema)
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
         count = 0
         for pdf in rows:
             list_state_rows = [(120,), (20,)]
@@ -1162,9 +1209,7 @@ class MapStateProcessor(StatefulProcessor):
         value_schema = StructType([StructField("count", IntegerType(), True)])
         self.map_state = handle.getMapState("mapState", key_schema, 
value_schema)
 
-    def handleInputRows(
-        self, key, rows, timer_values, expired_timer_info
-    ) -> Iterator[pd.DataFrame]:
+    def handleInputRows(self, key, rows, timer_values) -> 
Iterator[pd.DataFrame]:
         count = 0
         key1 = ("key1",)
         key2 = ("key2",)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 04f95e9f5264..1ebc04520eca 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -34,6 +34,7 @@ from pyspark.accumulators import (
     _deserialize_accumulator,
 )
 from pyspark.sql.streaming.stateful_processor_api_client import 
StatefulProcessorApiClient
+from pyspark.sql.streaming.stateful_processor_util import 
TransformWithStateInPandasFuncMode
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.resource import ResourceInformation
 from pyspark.util import PythonEvalType, local_connect_and_auth
@@ -493,36 +494,36 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, 
runner_conf):
 
 
 def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
-    def wrapped(stateful_processor_api_client, key, value_series_gen):
+    def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
         import pandas as pd
 
         values = (pd.concat(x, axis=1) for x in value_series_gen)
-        result_iter = f(stateful_processor_api_client, key, values)
+        result_iter = f(stateful_processor_api_client, mode, key, values)
 
         # TODO(SPARK-49100): add verification that elements in result_iter are
         # indeed of type pd.DataFrame and conform to the assigned columns
 
         return result_iter
 
-    return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))]
+    return lambda p, m, k, v: [(wrapped(p, m, k, v), 
to_arrow_type(return_type))]
 
 
 def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, 
runner_conf):
-    def wrapped(stateful_processor_api_client, key, value_series_gen):
+    def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
         import pandas as pd
 
         state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2)
         state_values = (df for x, _ in state_values_gen if not (df := 
pd.concat(x, axis=1)).empty)
         init_states = (df for _, x in init_states_gen if not (df := 
pd.concat(x, axis=1)).empty)
 
-        result_iter = f(stateful_processor_api_client, key, state_values, 
init_states)
+        result_iter = f(stateful_processor_api_client, mode, key, 
state_values, init_states)
 
         # TODO(SPARK-49100): add verification that elements in result_iter are
         # indeed of type pd.DataFrame and conform to the assigned columns
 
         return result_iter
 
-    return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))]
+    return lambda p, m, k, v: [(wrapped(p, m, k, v), 
to_arrow_type(return_type))]
 
 
 def wrap_grouped_map_pandas_udf_with_state(f, return_type):
@@ -1697,18 +1698,22 @@ def read_udfs(pickleSer, infile, eval_type):
         ser.key_offsets = parsed_offsets[0][0]
         stateful_processor_api_client = 
StatefulProcessorApiClient(state_server_port, key_schema)
 
-        # Create function like this:
-        #   mapper a: f([a[0]], [a[0], a[1]])
         def mapper(a):
-            key = a[0]
+            mode = a[0]
 
-            def values_gen():
-                for x in a[1]:
-                    retVal = [x[1][o] for o in parsed_offsets[0][1]]
-                    yield retVal
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+                key = a[1]
 
-            # This must be generator comprehension - do not materialize.
-            return f(stateful_processor_api_client, key, values_gen())
+                def values_gen():
+                    for x in a[2]:
+                        retVal = [x[1][o] for o in parsed_offsets[0][1]]
+                        yield retVal
+
+                # This must remain a lazy generator - do not materialize.
+                return f(stateful_processor_api_client, mode, key, 
values_gen())
+            else:
+                # mode == PROCESS_TIMER or mode == COMPLETE
+                return f(stateful_processor_api_client, mode, None, iter([]))
 
     elif eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
         # We assume there is only one UDF here because grouped map doesn't
@@ -1731,16 +1736,22 @@ def read_udfs(pickleSer, infile, eval_type):
         stateful_processor_api_client = 
StatefulProcessorApiClient(state_server_port, key_schema)
 
         def mapper(a):
-            key = a[0]
+            mode = a[0]
 
-            def values_gen():
-                for x in a[1]:
-                    retVal = [x[1][o] for o in parsed_offsets[0][1]]
-                    initVal = [x[2][o] for o in parsed_offsets[1][1]]
-                    yield retVal, initVal
+            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+                key = a[1]
 
-            # This must be generator comprehension - do not materialize.
-            return f(stateful_processor_api_client, key, values_gen())
+                def values_gen():
+                    for x in a[2]:
+                        retVal = [x[1][o] for o in parsed_offsets[0][1]]
+                        initVal = [x[2][o] for o in parsed_offsets[1][1]]
+                        yield retVal, initVal
+
+                # This must remain a lazy generator - do not materialize.
+                return f(stateful_processor_api_client, mode, key, 
values_gen())
+            else:
+                # mode == PROCESS_TIMER or mode == COMPLETE
+                return f(stateful_processor_api_client, mode, None, iter([]))
 
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
         import pyarrow as pa
@@ -1958,17 +1969,6 @@ def main(infile, outfile):
             try:
                 serializer.dump_stream(out_iter, outfile)
             finally:
-                # Sending a signal to TransformWithState UDF to perform proper 
cleanup steps.
-                if (
-                    eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
-                    or eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
-                ):
-                    # Sending key as None to indicate that process() has 
finished.
-                    end_iter = func(split_index, iter([(None, None)]))
-                    # Need to materialize the iterator to trigger the cleanup 
steps, nothing needs
-                    # to be done here.
-                    for _ in end_iter:
-                        pass
                 if hasattr(out_iter, "close"):
                     out_iter.close()
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 0373c8607ff2..2957f4b38758 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -120,6 +120,8 @@ class TransformWithStateInPandasStateServer(
   }
 
   /** Timer related class variables */
+  // An iterator over all expired timer info. It is meant to be consumed 
only once per
+  // partition, and only after all input rows have been handled.
   private var expiryTimestampIter: Option[Iterator[(Any, Long)]] =
     if (expiryTimerIterForTest != null) {
       Option(expiryTimerIterForTest)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to