This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 51d3efcead5b [SPARK-47233][CONNECT][SS][2/2] Client & Server logic for Client side streaming query listener
51d3efcead5b is described below

commit 51d3efcead5ba54b568a7be7f236179c6174e547
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Tue Apr 16 12:18:01 2024 +0900

    [SPARK-47233][CONNECT][SS][2/2] Client & Server logic for Client side streaming query listener
    
    ### What changes were proposed in this pull request?
    
    Server- and client-side logic for the client-side streaming query listener.
    
    The client sends an `add_listener_bus_listener` RPC when the first listener is added.
    Upon receiving the RPC, the server starts a long-running thread and registers a new `SparkConnectListenerBusListener`. The listener streams the listener events back to the client using the `responseObserver` created in the `executeHandler` of the `add_listener_bus_listener` call.
    
    On the client side, a new Spark client method, `execute_command_as_iterator`, continuously receives new events from the server through a long-running iterator. The client starts a new thread to handle these events; a simplified sketch follows.
    
    The long-running server thread is stopped when either the last client-side listener is removed (and the client sends a `remove_listener_bus_listener` call) or the `send` method of `SparkConnectListenerBusListener` throws. As a result, the final `ResultComplete` is sent to the client, closing the client's long-running iterator.
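    
    For reference, a minimal usage sketch of the client-side listener through the existing PySpark API (`spark` is assumed to be a Spark Connect session; the listener class is illustrative):
    
    ```python
    from pyspark.sql.streaming import StreamingQueryListener
    
    class MyListener(StreamingQueryListener):
        def onQueryStarted(self, event):
            print("started:", event.id)
    
        def onQueryProgress(self, event):
            print("progress:", event.progress.batchId)
    
        def onQueryIdle(self, event):
            pass
    
        def onQueryTerminated(self, event):
            print("terminated:", event.id)
    
    listener = MyListener()
    # Adding the first listener sends the add_listener_bus_listener RPC and starts
    # the client-side handler thread.
    spark.streams.addListener(listener)
    
    # ... run streaming queries; callbacks now fire in this Python process ...
    
    # Removing the last listener sends remove_listener_bus_listener; the call blocks
    # until remaining events are processed and the handler thread exits.
    spark.streams.removeListener(listener)
    ```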
    
    ### Why are the changes needed?
    
    Development of Spark Connect streaming.
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Added unit tests. Removed old unit tests that were created to verify server-side listener limitations.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46037 from WweiL/SPARK-47233-client-side-listener-2.
    
    Authored-by: Wei Liu <wei....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../service/SparkConnectListenerBusListener.scala  |  40 +++--
 python/pyspark/sql/connect/client/core.py          |  25 +++
 python/pyspark/sql/connect/streaming/query.py      | 196 ++++++++++++++++++---
 python/pyspark/sql/connect/streaming/readwriter.py |  12 +-
 .../connect/streaming/test_parity_listener.py      | 183 ++++++++++++++-----
 5 files changed, 376 insertions(+), 80 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
index 1b6c5179871d..56d0d920e95b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
@@ -51,7 +51,11 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
   val streamingQueryStartedEventCache
       : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap()
 
-  def isServerSideListenerRegistered: Boolean = streamingQueryServerSideListener.isDefined
+  val lock = new Object()
+
+  def isServerSideListenerRegistered: Boolean = lock.synchronized {
+    streamingQueryServerSideListener.isDefined
+  }
 
   /**
    * The initialization of the server side listener and related resources. This method is called
@@ -62,7 +66,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
    * @param responseObserver
    *   the responseObserver created from the first long running executeThread.
    */
-  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+  def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized {
     val serverListener = new SparkConnectListenerBusListener(this, responseObserver)
     sessionHolder.session.streams.addListener(serverListener)
     streamingQueryServerSideListener = Some(serverListener)
@@ -76,7 +80,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
    * the latch, so the long-running thread can proceed to send back the final ResultComplete
    * response.
    */
-  def cleanUp(): Unit = {
+  def cleanUp(): Unit = lock.synchronized {
     streamingQueryServerSideListener.foreach { listener =>
       sessionHolder.session.streams.removeListener(listener)
     }
@@ -106,18 +110,18 @@ private[sql] class SparkConnectListenerBusListener(
   // all related sources are cleaned up, and the long-running thread will proceed to send
   // the final ResultComplete response.
   private def send(eventJson: String, eventType: StreamingQueryEventType): Unit = {
-    val event = StreamingQueryListenerEvent
-      .newBuilder()
-      .setEventJson(eventJson)
-      .setEventType(eventType)
-      .build()
+    try {
+      val event = StreamingQueryListenerEvent
+        .newBuilder()
+        .setEventJson(eventJson)
+        .setEventType(eventType)
+        .build()
 
-    val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
-    val eventResult = respBuilder
-      .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava)
-      .build()
+      val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
+      val eventResult = respBuilder
+        .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava)
+        .build()
 
-    try {
       responseObserver.onNext(
         ExecutePlanResponse
           .newBuilder()
@@ -143,14 +147,24 @@ private[sql] class SparkConnectListenerBusListener(
   }
 
   override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
+    logDebug(
+      s"[SessionId: ${sessionHolder.sessionId}][UserId: 
${sessionHolder.userId}] " +
+        s"Sending QueryProgressEvent to client, id: ${event.progress.id}" +
+        s" runId: ${event.progress.runId}, batch: ${event.progress.batchId}.")
     send(event.json, StreamingQueryEventType.QUERY_PROGRESS_EVENT)
   }
 
   override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
+    logDebug(
+      s"[SessionId: ${sessionHolder.sessionId}][UserId: 
${sessionHolder.userId}] " +
+        s"Sending QueryTerminatedEvent to client, id: ${event.id} runId: 
${event.runId}.")
     send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT)
   }
 
   override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
+    logDebug(
+      s"[SessionId: ${sessionHolder.sessionId}][UserId: 
${sessionHolder.userId}] " +
+        s"Sending QueryIdleEvent to client, id: ${event.id} runId: 
${event.runId}.")
     send(event.json, StreamingQueryEventType.QUERY_IDLE_EVENT)
   }
 }
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 667b93596c5f..b6197fff6fd5 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1067,6 +1067,28 @@ class SparkConnectClient(object):
         else:
             return (None, properties)
 
+    def execute_command_as_iterator(
+        self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
+    ) -> Iterator[Dict[str, Any]]:
+        """
+        Execute given command. Similar to execute_command, but the value is returned using yield.
+        """
+        logger.info(f"Execute command as iterator for command 
{self._proto_to_string(command)}")
+        req = self._execute_plan_request_with_metadata()
+        if self._user_id:
+            req.user_context.user_id = self._user_id
+        req.plan.command.CopyFrom(command)
+        for response in self._execute_and_fetch_as_iterator(req, observations or {}):
+            if isinstance(response, dict):
+                yield response
+            else:
+                raise PySparkValueError(
+                    error_class="UNKNOWN_RESPONSE",
+                    message_parameters={
+                        "response": str(response),
+                    },
+                )
+
     def same_semantics(self, plan: pb2.Plan, other: pb2.Plan) -> bool:
         """
         return if two plans have the same semantics.
@@ -1330,6 +1352,9 @@ class SparkConnectClient(object):
             if b.HasField("streaming_query_manager_command_result"):
                 cmd_result = b.streaming_query_manager_command_result
                 yield {"streaming_query_manager_command_result": cmd_result}
+            if b.HasField("streaming_query_listener_events_result"):
+                event_result = b.streaming_query_listener_events_result
+                yield {"streaming_query_listener_events_result": event_result}
             if b.HasField("get_resources_command_result"):
                 resources = {}
                 for key, resource in b.get_resources_command_result.resources.items():
diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py
index c1940921c631..0624f8943ac4 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -20,15 +20,20 @@ check_dependencies(__name__)
 
 import json
 import sys
-import pickle
-from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional
+import warnings
+from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional, Union, Iterator
+from threading import Thread, Lock
 
 from pyspark.errors import StreamingQueryException, PySparkValueError
 import pyspark.sql.connect.proto as pb2
-from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.connect import proto
-from pyspark.sql.connect.utils import get_python_ver
 from pyspark.sql.streaming import StreamingQueryListener
+from pyspark.sql.streaming.listener import (
+    QueryStartedEvent,
+    QueryProgressEvent,
+    QueryIdleEvent,
+    QueryTerminatedEvent,
+)
 from pyspark.sql.streaming.query import (
     StreamingQuery as PySparkStreamingQuery,
     StreamingQueryManager as PySparkStreamingQueryManager,
@@ -36,7 +41,6 @@ from pyspark.sql.streaming.query import (
 from pyspark.errors.exceptions.connect import (
     StreamingQueryException as CapturedStreamingQueryException,
 )
-from pyspark.errors import PySparkPicklingError
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.session import SparkSession
@@ -184,6 +188,7 @@ class StreamingQuery:
 class StreamingQueryManager:
     def __init__(self, session: "SparkSession") -> None:
         self._session = session
+        self._sqlb = StreamingQueryListenerBus(self)
 
     @property
     def active(self) -> List[StreamingQuery]:
@@ -237,27 +242,13 @@ class StreamingQueryManager:
     resetTerminated.__doc__ = PySparkStreamingQueryManager.resetTerminated.__doc__
 
     def addListener(self, listener: StreamingQueryListener) -> None:
-        listener._init_listener_id()
-        cmd = pb2.StreamingQueryManagerCommand()
-        expr = proto.PythonUDF()
-        try:
-            expr.command = CloudPickleSerializer().dumps(listener)
-        except pickle.PicklingError:
-            raise PySparkPicklingError(
-                error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
-                message_parameters={"name": "addListener"},
-            )
-        expr.python_ver = get_python_ver()
-        cmd.add_listener.python_listener_payload.CopyFrom(expr)
-        cmd.add_listener.id = listener._id
-        self._execute_streaming_query_manager_cmd(cmd)
+        listener._set_spark_session(self._session)
+        self._sqlb.append(listener)
 
     addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__
 
     def removeListener(self, listener: StreamingQueryListener) -> None:
-        cmd = pb2.StreamingQueryManagerCommand()
-        cmd.remove_listener.id = listener._id
-        self._execute_streaming_query_manager_cmd(cmd)
+        self._sqlb.remove(listener)
 
     removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__
 
@@ -273,6 +264,167 @@ class StreamingQueryManager:
         )
 
 
+class StreamingQueryListenerBus:
+    """
+    A client side listener bus that is responsible for buffering client side listeners,
+    receive listener events and invoke correct listener call backs.
+    """
+
+    def __init__(self, sqm: "StreamingQueryManager") -> None:
+        self._sqm = sqm
+        self._listener_bus: List[StreamingQueryListener] = []
+        self._execution_thread: Optional[Thread] = None
+        self._lock = Lock()
+
+    def append(self, listener: StreamingQueryListener) -> None:
+        """
+        Append a listener to the local listener bus. When the added listener is
+        the first listener, request the server to create the server side listener
+        and start a thread to handle query events.
+        """
+        with self._lock:
+            self._listener_bus.append(listener)
+
+            if len(self._listener_bus) == 1:
+                assert self._execution_thread is None
+                try:
+                    result_iter = self._register_server_side_listener()
+                except Exception as e:
+                    warnings.warn(
+                        f"Failed to add the listener because of exception: 
{e}\n"
+                        f"The listener is not added, please add it again."
+                    )
+                    self._listener_bus.remove(listener)
+                    return
+                self._execution_thread = Thread(
+                    target=self._query_event_handler, args=(result_iter,)
+                )
+                self._execution_thread.start()
+
+    def remove(self, listener: StreamingQueryListener) -> None:
+        """
+        Remove the listener from the local listener bus.
+
+        When the listener is not presented in the listener bus, do nothing.
+
+        When the removed listener is the last listener, ask the server to remove
+        the server side listener.
+        As a result, the listener handling thread created before
+        will return after processing remaining listener events. This function blocks until
+        all events are processed.
+        """
+        with self._lock:
+            if listener not in self._listener_bus:
+                return
+
+            if len(self._listener_bus) == 1:
+                cmd = pb2.StreamingQueryListenerBusCommand()
+                cmd.remove_listener_bus_listener = True
+                exec_cmd = pb2.Command()
+                exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd)
+                try:
+                    self._sqm._session.client.execute_command(exec_cmd)
+                except Exception as e:
+                    warnings.warn(
+                        f"Failed to remove the listener because of exception: 
{e}\n"
+                        f"The listener is not removed, please remove it again."
+                    )
+                    return
+                if self._execution_thread is not None:
+                    self._execution_thread.join()
+                    self._execution_thread = None
+
+            self._listener_bus.remove(listener)
+
+    def _register_server_side_listener(self) -> Iterator[Dict[str, Any]]:
+        """
+        Send add listener request to the server, after received confirmation from the server,
+        start a new thread to handle these events.
+        """
+        cmd = pb2.StreamingQueryListenerBusCommand()
+        cmd.add_listener_bus_listener = True
+        exec_cmd = pb2.Command()
+        exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd)
+        result_iter = self._sqm._session.client.execute_command_as_iterator(exec_cmd)
+        # Main thread should block until received listener_added_success message
+        for result in result_iter:
+            response = cast(
+                pb2.StreamingQueryListenerEventsResult,
+                result["streaming_query_listener_events_result"],
+            )
+            if response.HasField("listener_bus_listener_added"):
+                break
+        return result_iter
+
+    def _query_event_handler(self, iter: Iterator[Dict[str, Any]]) -> None:
+        """
+        Handler function passed to the new thread, if there is any error while receiving
+        listener events, it means the connection is unstable. In this case, remove all listeners
+        and tell the user to add back the listeners.
+        """
+        try:
+            for result in iter:
+                response = cast(
+                    pb2.StreamingQueryListenerEventsResult,
+                    result["streaming_query_listener_events_result"],
+                )
+                for event in response.events:
+                    deserialized_event = self.deserialize(event)
+                    self.post_to_all(deserialized_event)
+
+        except Exception as e:
+            warnings.warn(
+                "StreamingQueryListenerBus Handler thread received exception, 
all client side "
+                f"listeners are removed and handler thread is terminated. The 
error is: {e}"
+            )
+            with self._lock:
+                self._execution_thread = None
+                self._listener_bus.clear()
+            return
+
+    @staticmethod
+    def deserialize(
+        event: pb2.StreamingQueryListenerEvent,
+    ) -> Union["QueryProgressEvent", "QueryIdleEvent", "QueryTerminatedEvent"]:
+        if event.event_type == proto.StreamingQueryEventType.QUERY_PROGRESS_EVENT:
+            return QueryProgressEvent.fromJson(json.loads(event.event_json))
+        elif event.event_type == proto.StreamingQueryEventType.QUERY_TERMINATED_EVENT:
+            return QueryTerminatedEvent.fromJson(json.loads(event.event_json))
+        elif event.event_type == proto.StreamingQueryEventType.QUERY_IDLE_EVENT:
+            return QueryIdleEvent.fromJson(json.loads(event.event_json))
+        else:
+            raise PySparkValueError(
+                error_class="UNKNOWN_VALUE_FOR",
+                message_parameters={"var": f"proto.StreamingQueryEventType: 
{event.event_type}"},
+            )
+
+    def post_to_all(
+        self,
+        event: Union[
+            "QueryStartedEvent", "QueryProgressEvent", "QueryIdleEvent", 
"QueryTerminatedEvent"
+        ],
+    ) -> None:
+        """
+        Post listener events to all active listeners, note that if one listener throws,
+        it should not affect other listeners.
+        """
+        with self._lock:
+            for listener in self._listener_bus:
+                try:
+                    if isinstance(event, QueryStartedEvent):
+                        listener.onQueryStarted(event)
+                    elif isinstance(event, QueryProgressEvent):
+                        listener.onQueryProgress(event)
+                    elif isinstance(event, QueryIdleEvent):
+                        listener.onQueryIdle(event)
+                    elif isinstance(event, QueryTerminatedEvent):
+                        listener.onQueryTerminated(event)
+                    else:
+                        warnings.warn(f"Unknown StreamingQueryListener event: 
{event}")
+                except Exception as e:
+                    warnings.warn(f"Listener {str(listener)} threw an 
exception\n{e}")
+
+
 def _test() -> None:
     import doctest
     import os
diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py
index ac0aca6d4b19..4973bb5b6cf7 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -18,6 +18,7 @@ from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
 
+import json
 import sys
 import pickle
 from typing import cast, overload, Callable, Dict, List, Optional, TYPE_CHECKING, Union
@@ -31,6 +32,7 @@ from pyspark.sql.streaming.readwriter import (
     DataStreamReader as PySparkDataStreamReader,
     DataStreamWriter as PySparkDataStreamWriter,
 )
+from pyspark.sql.streaming.listener import QueryStartedEvent
 from pyspark.sql.connect.utils import get_python_ver
 from pyspark.sql.types import Row, StructType
 from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkPicklingError
@@ -599,13 +601,21 @@ class DataStreamWriter:
         start_result = cast(
             pb2.WriteStreamOperationStartResult, properties["write_stream_operation_start_result"]
         )
-        return StreamingQuery(
+        query = StreamingQuery(
             session=self._session,
             queryId=start_result.query_id.id,
             runId=start_result.query_id.run_id,
             name=start_result.name,
         )
 
+        if start_result.HasField("query_started_event_json"):
+            start_event = QueryStartedEvent.fromJson(
+                json.loads(start_result.query_started_event_json)
+            )
+            self._session.streams._sqlb.post_to_all(start_event)
+
+        return query
+
     def start(
         self,
         path: Optional[str] = None,
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index a15e4547f67a..be8c30c28ce0 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -25,9 +25,10 @@ from pyspark.sql.functions import count, lit
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
+# Listeners that has spark commands in callback handler functions
 # V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`,
 # `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5.
-class TestListenerV1(StreamingQueryListener):
+class TestListenerSparkV1(StreamingQueryListener):
     def onQueryStarted(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
@@ -45,7 +46,7 @@ class TestListenerV1(StreamingQueryListener):
 
 
 # V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+.
-class TestListenerV2(StreamingQueryListener):
+class TestListenerSparkV2(StreamingQueryListener):
     def onQueryStarted(self, event):
         e = pyspark.cloudpickle.dumps(event)
         df = self.spark.createDataFrame(data=[(e,)])
@@ -65,8 +66,140 @@ class TestListenerV2(StreamingQueryListener):
         df.write.mode("append").saveAsTable("listener_terminated_events_v2")
 
 
+class TestListenerLocal(StreamingQueryListener):
+    def __init__(self):
+        self.start = []
+        self.progress = []
+        self.terminated = []
+
+    def onQueryStarted(self, event):
+        self.start.append(event)
+
+    def onQueryProgress(self, event):
+        self.progress.append(event)
+
+    def onQueryIdle(self, event):
+        pass
+
+    def onQueryTerminated(self, event):
+        self.terminated.append(event)
+
+
 class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase):
-    def test_listener_events(self):
+    def test_listener_management(self):
+        listener1 = TestListenerLocal()
+        listener2 = TestListenerLocal()
+
+        try:
+            self.spark.streams.addListener(listener1)
+            self.spark.streams.addListener(listener2)
+            q = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
+
+            # Both listeners should have listener events already because onQueryStarted
+            # is always called before DataStreamWriter.start() returns
+            self.assertEqual(len(listener1.start), 1)
+            self.assertEqual(len(listener2.start), 1)
+
+            # removeListener is a blocking call, resources are cleaned up by the time it returns
+            self.spark.streams.removeListener(listener1)
+            self.spark.streams.removeListener(listener2)
+
+            # Add back the listener and stop the query, now should see a terminated event
+            self.spark.streams.addListener(listener1)
+            q.stop()
+
+            # need to wait a while before QueryTerminatedEvent reaches client
+            time.sleep(15)
+            self.assertEqual(len(listener1.terminated), 1)
+
+            self.check_start_event(listener1.start[0])
+            for event in listener1.progress:
+                self.check_progress_event(event)
+            self.check_terminated_event(listener1.terminated[0])
+
+        finally:
+            for listener in self.spark.streams._sqlb._listener_bus:
+                self.spark.streams.removeListener(listener)
+            for q in self.spark.streams.active:
+                q.stop()
+
+    def test_slow_query(self):
+        try:
+            listener = TestListenerLocal()
+            self.spark.streams.addListener(listener)
+
+            slow_query = (
+                self.spark.readStream.format("rate")
+                .load()
+                .writeStream.format("noop")
+                .trigger(processingTime="20 seconds")
+                .start()
+            )
+            fast_query = (
+                self.spark.readStream.format("rate").load().writeStream.format("noop").start()
+            )
+
+            while slow_query.lastProgress is None:
+                slow_query.awaitTermination(20)
+
+            slow_query.stop()
+            fast_query.stop()
+
+            self.assertTrue(slow_query.id in [str(e.id) for e in listener.start])
+            self.assertTrue(fast_query.id in [str(e.id) for e in listener.start])
+
+            self.assertTrue(slow_query.id in [str(e.progress.id) for e in listener.progress])
+            self.assertTrue(fast_query.id in [str(e.progress.id) for e in listener.progress])
+
+            self.assertTrue(slow_query.id in [str(e.id) for e in listener.terminated])
+            self.assertTrue(fast_query.id in [str(e.id) for e in listener.terminated])
+
+        finally:
+            for listener in self.spark.streams._sqlb._listener_bus:
+                self.spark.streams.removeListener(listener)
+            for q in self.spark.streams.active:
+                q.stop()
+
+    def test_listener_throw(self):
+        """
+        Following Vanilla Spark's behavior, when the callback of user-defined listener throws,
+        other listeners should still proceed.
+        """
+
+        class UselessListener(StreamingQueryListener):
+            def onQueryStarted(self, e):
+                raise Exception("My bad!")
+
+            def onQueryProgress(self, e):
+                raise Exception("My bad again!")
+
+            def onQueryTerminated(self, e):
+                raise Exception("I'm so sorry!")
+
+        try:
+            listener_good = TestListenerLocal()
+            listener_bad = UselessListener()
+            self.spark.streams.addListener(listener_good)
+            self.spark.streams.addListener(listener_bad)
+
+            q = self.spark.readStream.format("rate").load().writeStream.format("noop").start()
+
+            while q.lastProgress is None:
+                q.awaitTermination(0.5)
+
+            q.stop()
+            # need to wait a while before QueryTerminatedEvent reaches client
+            time.sleep(5)
+            self.assertTrue(len(listener_good.start) > 0)
+            self.assertTrue(len(listener_good.progress) > 0)
+            self.assertTrue(len(listener_good.terminated) > 0)
+        finally:
+            for listener in self.spark.streams._sqlb._listener_bus:
+                self.spark.streams.removeListener(listener)
+            for q in self.spark.streams.active:
+                q.stop()
+
+    def test_listener_events_spark_command(self):
         def verify(test_listener, table_postfix):
             try:
                 self.spark.streams.addListener(test_listener)
@@ -88,7 +221,7 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
                 self.assertTrue(q.isActive)
                 # ensure at least one batch is ran
                 while q.lastProgress is None or q.lastProgress["batchId"] == 0:
-                    time.sleep(5)
+                    q.awaitTermination(5)
                 q.stop()
                 self.assertFalse(q.isActive)
 
@@ -129,46 +262,8 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
             "listener_progress_events_v2",
             "listener_terminated_events_v2",
         ):
-            verify(TestListenerV1(), "_v1")
-            verify(TestListenerV2(), "_v2")
-
-    def test_accessing_spark_session(self):
-        spark = self.spark
-
-        class TestListener(StreamingQueryListener):
-            def onQueryStarted(self, event):
-                spark.createDataFrame(
-                    [("you", "can"), ("serialize", "spark")]
-                ).createOrReplaceTempView("test_accessing_spark_session")
-
-            def onQueryProgress(self, event):
-                pass
-
-            def onQueryIdle(self, event):
-                pass
-
-            def onQueryTerminated(self, event):
-                pass
-
-        self.spark.streams.addListener(TestListener())
-
-    def test_accessing_spark_session_through_df(self):
-        dataframe = self.spark.createDataFrame([("you", "can"), ("serialize", "dataframe")])
-
-        class TestListener(StreamingQueryListener):
-            def onQueryStarted(self, event):
-                dataframe.collect()
-
-            def onQueryProgress(self, event):
-                pass
-
-            def onQueryIdle(self, event):
-                pass
-
-            def onQueryTerminated(self, event):
-                pass
-
-        self.spark.streams.addListener(TestListener())
+            verify(TestListenerSparkV1(), "_v1")
+            verify(TestListenerSparkV2(), "_v2")
 
 
 if __name__ == "__main__":

