This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 51d3efcead5b [SPARK-47233][CONNECT][SS][2/2] Client & Server logic for Client side streaming query listener
51d3efcead5b is described below

commit 51d3efcead5ba54b568a7be7f236179c6174e547
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Tue Apr 16 12:18:01 2024 +0900

[SPARK-47233][CONNECT][SS][2/2] Client & Server logic for Client side streaming query listener

### What changes were proposed in this pull request?

Server- and client-side logic for the client-side streaming query listener.

The client sends an `add_listener_bus_listener` RPC when the first listener is added. Upon receiving the RPC, the server starts a long-running thread and registers a new `SparkConnectListenerBusListener`; this listener streams the listener events back to the client using the `responseObserver` created in the `executeHandler` of the `add_listener_bus_listener` call.

On the client side, a new Spark client method, `execute_command_as_iterator`, continuously receives new events from the server through a long-running iterator. The client starts a new thread for handling these events. Please see the graphs below for a more detailed illustration.

When either the last client-side listener is removed and the client sends the `remove_listener_bus_listener` call, or the `send` method of `SparkConnectListenerBusListener` throws, the long-running server thread is stopped. As a result, the final `ResultComplete` is sent to the client, closing the client's long-running iterator.

### Why are the changes needed?

Development of Spark Connect streaming.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Added a unit test. Removed an old unit test that was created to verify server-side listener limitations.

### Was this patch authored or co-authored using generative AI tooling?

No
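For illustration, a minimal sketch of the user-facing flow: a client-side listener whose callbacks run in the local Python process rather than on the server. The `sc://localhost` address is an assumption, and the rate-to-noop query mirrors the new tests.

```python
from pyspark.sql import SparkSession
from pyspark.sql.streaming import StreamingQueryListener


class MyListener(StreamingQueryListener):
    # Callbacks run in the client process; no Spark command is needed inside them.
    def onQueryStarted(self, event):
        print(f"started: id={event.id} runId={event.runId}")

    def onQueryProgress(self, event):
        print(f"progress: batchId={event.progress.batchId}")

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        print(f"terminated: id={event.id}")


spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

listener = MyListener()
# Adding the first listener triggers the add_listener_bus_listener RPC and starts
# the client-side handler thread.
spark.streams.addListener(listener)

# onQueryStarted is delivered before start() returns.
query = spark.readStream.format("rate").load().writeStream.format("noop").start()

query.stop()

# Removing the last listener sends remove_listener_bus_listener; the server thread
# finishes with a final ResultComplete and the client handler thread exits.
spark.streams.removeListener(listener)
```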
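A condensed sketch of what the handler thread does with the long-running iterator returned by `execute_command_as_iterator`, simplified from the new `StreamingQueryListenerBus` in query.py below. Locking, warnings, and cleanup are omitted, and `handle_events` is a hypothetical helper name used only for illustration.

```python
import json

from pyspark.sql.connect import proto
from pyspark.sql.streaming.listener import (
    QueryIdleEvent,
    QueryProgressEvent,
    QueryTerminatedEvent,
)


def handle_events(result_iter, listeners):
    # result_iter yields dicts produced by execute_command_as_iterator; each one
    # wraps a StreamingQueryListenerEventsResult streamed back by the server.
    for result in result_iter:
        response = result["streaming_query_listener_events_result"]
        for event in response.events:
            payload = json.loads(event.event_json)
            if event.event_type == proto.StreamingQueryEventType.QUERY_PROGRESS_EVENT:
                deserialized, callback = QueryProgressEvent.fromJson(payload), "onQueryProgress"
            elif event.event_type == proto.StreamingQueryEventType.QUERY_TERMINATED_EVENT:
                deserialized, callback = QueryTerminatedEvent.fromJson(payload), "onQueryTerminated"
            elif event.event_type == proto.StreamingQueryEventType.QUERY_IDLE_EVENT:
                deserialized, callback = QueryIdleEvent.fromJson(payload), "onQueryIdle"
            else:
                continue
            for listener in listeners:
                try:
                    # One listener throwing must not affect the others.
                    getattr(listener, callback)(deserialized)
                except Exception as e:
                    print(f"listener {listener} threw: {e}")
```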
Closes #46037 from WweiL/SPARK-47233-client-side-listener-2.

Authored-by: Wei Liu <wei....@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../service/SparkConnectListenerBusListener.scala  |  40 +++--
 python/pyspark/sql/connect/client/core.py          |  25 +++
 python/pyspark/sql/connect/streaming/query.py      | 196 ++++++++++++++++++---
 python/pyspark/sql/connect/streaming/readwriter.py |  12 +-
 .../connect/streaming/test_parity_listener.py      | 183 ++++++++++++++-----
 5 files changed, 376 insertions(+), 80 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala index 1b6c5179871d..56d0d920e95b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala @@ -51,7 +51,11 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { val streamingQueryStartedEventCache : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap() - def isServerSideListenerRegistered: Boolean = streamingQueryServerSideListener.isDefined + val lock = new Object() + + def isServerSideListenerRegistered: Boolean = lock.synchronized { + streamingQueryServerSideListener.isDefined + } /** * The initialization of the server side listener and related resources.
This method is called @@ -62,7 +66,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { * @param responseObserver * the responseObserver created from the first long running executeThread. */ - def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized { val serverListener = new SparkConnectListenerBusListener(this, responseObserver) sessionHolder.session.streams.addListener(serverListener) streamingQueryServerSideListener = Some(serverListener) @@ -76,7 +80,7 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { * the latch, so the long-running thread can proceed to send back the final ResultComplete * response. */ - def cleanUp(): Unit = { + def cleanUp(): Unit = lock.synchronized { streamingQueryServerSideListener.foreach { listener => sessionHolder.session.streams.removeListener(listener) } @@ -106,18 +110,18 @@ private[sql] class SparkConnectListenerBusListener( // all related sources are cleaned up, and the long-running thread will proceed to send // the final ResultComplete response. private def send(eventJson: String, eventType: StreamingQueryEventType): Unit = { - val event = StreamingQueryListenerEvent - .newBuilder() - .setEventJson(eventJson) - .setEventType(eventType) - .build() + try { + val event = StreamingQueryListenerEvent + .newBuilder() + .setEventJson(eventJson) + .setEventType(eventType) + .build() - val respBuilder = StreamingQueryListenerEventsResult.newBuilder() - val eventResult = respBuilder - .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava) - .build() + val respBuilder = StreamingQueryListenerEventsResult.newBuilder() + val eventResult = respBuilder + .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava) + .build() - try { responseObserver.onNext( ExecutePlanResponse .newBuilder() @@ -143,14 +147,24 @@ private[sql] class SparkConnectListenerBusListener( } override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = { + logDebug( + s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " + + s"Sending QueryProgressEvent to client, id: ${event.progress.id}" + + s" runId: ${event.progress.runId}, batch: ${event.progress.batchId}.") send(event.json, StreamingQueryEventType.QUERY_PROGRESS_EVENT) } override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { + logDebug( + s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " + + s"Sending QueryTerminatedEvent to client, id: ${event.id} runId: ${event.runId}.") send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT) } override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = { + logDebug( + s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " + + s"Sending QueryIdleEvent to client, id: ${event.id} runId: ${event.runId}.") send(event.json, StreamingQueryEventType.QUERY_IDLE_EVENT) } } diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 667b93596c5f..b6197fff6fd5 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1067,6 +1067,28 @@ class SparkConnectClient(object): else: return (None, properties) + def execute_command_as_iterator( + self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None + ) ->
Iterator[Dict[str, Any]]: + """ + Execute given command. Similar to execute_command, but the value is returned using yield. + """ + logger.info(f"Execute command as iterator for command {self._proto_to_string(command)}") + req = self._execute_plan_request_with_metadata() + if self._user_id: + req.user_context.user_id = self._user_id + req.plan.command.CopyFrom(command) + for response in self._execute_and_fetch_as_iterator(req, observations or {}): + if isinstance(response, dict): + yield response + else: + raise PySparkValueError( + error_class="UNKNOWN_RESPONSE", + message_parameters={ + "response": str(response), + }, + ) + def same_semantics(self, plan: pb2.Plan, other: pb2.Plan) -> bool: """ return if two plans have the same semantics. @@ -1330,6 +1352,9 @@ class SparkConnectClient(object): if b.HasField("streaming_query_manager_command_result"): cmd_result = b.streaming_query_manager_command_result yield {"streaming_query_manager_command_result": cmd_result} + if b.HasField("streaming_query_listener_events_result"): + event_result = b.streaming_query_listener_events_result + yield {"streaming_query_listener_events_result": event_result} if b.HasField("get_resources_command_result"): resources = {} for key, resource in b.get_resources_command_result.resources.items(): diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index c1940921c631..0624f8943ac4 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -20,15 +20,20 @@ check_dependencies(__name__) import json import sys -import pickle -from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional +import warnings +from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional, Union, Iterator +from threading import Thread, Lock from pyspark.errors import StreamingQueryException, PySparkValueError import pyspark.sql.connect.proto as pb2 -from pyspark.serializers import CloudPickleSerializer from pyspark.sql.connect import proto -from pyspark.sql.connect.utils import get_python_ver from pyspark.sql.streaming import StreamingQueryListener +from pyspark.sql.streaming.listener import ( + QueryStartedEvent, + QueryProgressEvent, + QueryIdleEvent, + QueryTerminatedEvent, +) from pyspark.sql.streaming.query import ( StreamingQuery as PySparkStreamingQuery, StreamingQueryManager as PySparkStreamingQueryManager, @@ -36,7 +41,6 @@ from pyspark.sql.streaming.query import ( from pyspark.errors.exceptions.connect import ( StreamingQueryException as CapturedStreamingQueryException, ) -from pyspark.errors import PySparkPicklingError if TYPE_CHECKING: from pyspark.sql.connect.session import SparkSession @@ -184,6 +188,7 @@ class StreamingQuery: class StreamingQueryManager: def __init__(self, session: "SparkSession") -> None: self._session = session + self._sqlb = StreamingQueryListenerBus(self) @property def active(self) -> List[StreamingQuery]: @@ -237,27 +242,13 @@ class StreamingQueryManager: resetTerminated.__doc__ = PySparkStreamingQueryManager.resetTerminated.__doc__ def addListener(self, listener: StreamingQueryListener) -> None: - listener._init_listener_id() - cmd = pb2.StreamingQueryManagerCommand() - expr = proto.PythonUDF() - try: - expr.command = CloudPickleSerializer().dumps(listener) - except pickle.PicklingError: - raise PySparkPicklingError( - error_class="STREAMING_CONNECT_SERIALIZATION_ERROR", - message_parameters={"name": "addListener"}, - ) - expr.python_ver = get_python_ver() - 
cmd.add_listener.python_listener_payload.CopyFrom(expr) - cmd.add_listener.id = listener._id - self._execute_streaming_query_manager_cmd(cmd) + listener._set_spark_session(self._session) + self._sqlb.append(listener) addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__ def removeListener(self, listener: StreamingQueryListener) -> None: - cmd = pb2.StreamingQueryManagerCommand() - cmd.remove_listener.id = listener._id - self._execute_streaming_query_manager_cmd(cmd) + self._sqlb.remove(listener) removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__ @@ -273,6 +264,167 @@ class StreamingQueryManager: ) +class StreamingQueryListenerBus: + """ + A client side listener bus that is responsible for buffering client side listeners, + receive listener events and invoke correct listener call backs. + """ + + def __init__(self, sqm: "StreamingQueryManager") -> None: + self._sqm = sqm + self._listener_bus: List[StreamingQueryListener] = [] + self._execution_thread: Optional[Thread] = None + self._lock = Lock() + + def append(self, listener: StreamingQueryListener) -> None: + """ + Append a listener to the local listener bus. When the added listener is + the first listener, request the server to create the server side listener + and start a thread to handle query events. + """ + with self._lock: + self._listener_bus.append(listener) + + if len(self._listener_bus) == 1: + assert self._execution_thread is None + try: + result_iter = self._register_server_side_listener() + except Exception as e: + warnings.warn( + f"Failed to add the listener because of exception: {e}\n" + f"The listener is not added, please add it again." + ) + self._listener_bus.remove(listener) + return + self._execution_thread = Thread( + target=self._query_event_handler, args=(result_iter,) + ) + self._execution_thread.start() + + def remove(self, listener: StreamingQueryListener) -> None: + """ + Remove the listener from the local listener bus. + + When the listener is not presented in the listener bus, do nothing. + + When the removed listener is the last listener, ask the server to remove + the server side listener. + As a result, the listener handling thread created before + will return after processing remaining listener events. This function blocks until + all events are processed. + """ + with self._lock: + if listener not in self._listener_bus: + return + + if len(self._listener_bus) == 1: + cmd = pb2.StreamingQueryListenerBusCommand() + cmd.remove_listener_bus_listener = True + exec_cmd = pb2.Command() + exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd) + try: + self._sqm._session.client.execute_command(exec_cmd) + except Exception as e: + warnings.warn( + f"Failed to remove the listener because of exception: {e}\n" + f"The listener is not removed, please remove it again." + ) + return + if self._execution_thread is not None: + self._execution_thread.join() + self._execution_thread = None + + self._listener_bus.remove(listener) + + def _register_server_side_listener(self) -> Iterator[Dict[str, Any]]: + """ + Send add listener request to the server, after received confirmation from the server, + start a new thread to handle these events. 
+ """ + cmd = pb2.StreamingQueryListenerBusCommand() + cmd.add_listener_bus_listener = True + exec_cmd = pb2.Command() + exec_cmd.streaming_query_listener_bus_command.CopyFrom(cmd) + result_iter = self._sqm._session.client.execute_command_as_iterator(exec_cmd) + # Main thread should block until received listener_added_success message + for result in result_iter: + response = cast( + pb2.StreamingQueryListenerEventsResult, + result["streaming_query_listener_events_result"], + ) + if response.HasField("listener_bus_listener_added"): + break + return result_iter + + def _query_event_handler(self, iter: Iterator[Dict[str, Any]]) -> None: + """ + Handler function passed to the new thread, if there is any error while receiving + listener events, it means the connection is unstable. In this case, remove all listeners + and tell the user to add back the listeners. + """ + try: + for result in iter: + response = cast( + pb2.StreamingQueryListenerEventsResult, + result["streaming_query_listener_events_result"], + ) + for event in response.events: + deserialized_event = self.deserialize(event) + self.post_to_all(deserialized_event) + + except Exception as e: + warnings.warn( + "StreamingQueryListenerBus Handler thread received exception, all client side " + f"listeners are removed and handler thread is terminated. The error is: {e}" + ) + with self._lock: + self._execution_thread = None + self._listener_bus.clear() + return + + @staticmethod + def deserialize( + event: pb2.StreamingQueryListenerEvent, + ) -> Union["QueryProgressEvent", "QueryIdleEvent", "QueryTerminatedEvent"]: + if event.event_type == proto.StreamingQueryEventType.QUERY_PROGRESS_EVENT: + return QueryProgressEvent.fromJson(json.loads(event.event_json)) + elif event.event_type == proto.StreamingQueryEventType.QUERY_TERMINATED_EVENT: + return QueryTerminatedEvent.fromJson(json.loads(event.event_json)) + elif event.event_type == proto.StreamingQueryEventType.QUERY_IDLE_EVENT: + return QueryIdleEvent.fromJson(json.loads(event.event_json)) + else: + raise PySparkValueError( + error_class="UNKNOWN_VALUE_FOR", + message_parameters={"var": f"proto.StreamingQueryEventType: {event.event_type}"}, + ) + + def post_to_all( + self, + event: Union[ + "QueryStartedEvent", "QueryProgressEvent", "QueryIdleEvent", "QueryTerminatedEvent" + ], + ) -> None: + """ + Post listener events to all active listeners, note that if one listener throws, + it should not affect other listeners. 
+ """ + with self._lock: + for listener in self._listener_bus: + try: + if isinstance(event, QueryStartedEvent): + listener.onQueryStarted(event) + elif isinstance(event, QueryProgressEvent): + listener.onQueryProgress(event) + elif isinstance(event, QueryIdleEvent): + listener.onQueryIdle(event) + elif isinstance(event, QueryTerminatedEvent): + listener.onQueryTerminated(event) + else: + warnings.warn(f"Unknown StreamingQueryListener event: {event}") + except Exception as e: + warnings.warn(f"Listener {str(listener)} threw an exception\n{e}") + + def _test() -> None: import doctest import os diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index ac0aca6d4b19..4973bb5b6cf7 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -18,6 +18,7 @@ from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) +import json import sys import pickle from typing import cast, overload, Callable, Dict, List, Optional, TYPE_CHECKING, Union @@ -31,6 +32,7 @@ from pyspark.sql.streaming.readwriter import ( DataStreamReader as PySparkDataStreamReader, DataStreamWriter as PySparkDataStreamWriter, ) +from pyspark.sql.streaming.listener import QueryStartedEvent from pyspark.sql.connect.utils import get_python_ver from pyspark.sql.types import Row, StructType from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkPicklingError @@ -599,13 +601,21 @@ class DataStreamWriter: start_result = cast( pb2.WriteStreamOperationStartResult, properties["write_stream_operation_start_result"] ) - return StreamingQuery( + query = StreamingQuery( session=self._session, queryId=start_result.query_id.id, runId=start_result.query_id.run_id, name=start_result.name, ) + if start_result.HasField("query_started_event_json"): + start_event = QueryStartedEvent.fromJson( + json.loads(start_result.query_started_event_json) + ) + self._session.streams._sqlb.post_to_all(start_event) + + return query + def start( self, path: Optional[str] = None, diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index a15e4547f67a..be8c30c28ce0 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -25,9 +25,10 @@ from pyspark.sql.functions import count, lit from pyspark.testing.connectutils import ReusedConnectTestCase +# Listeners that has spark commands in callback handler functions # V1: Initial interface of StreamingQueryListener containing methods `onQueryStarted`, # `onQueryProgress`, `onQueryTerminated`. It is prior to Spark 3.5. -class TestListenerV1(StreamingQueryListener): +class TestListenerSparkV1(StreamingQueryListener): def onQueryStarted(self, event): e = pyspark.cloudpickle.dumps(event) df = self.spark.createDataFrame(data=[(e,)]) @@ -45,7 +46,7 @@ class TestListenerV1(StreamingQueryListener): # V2: The interface after the method `onQueryIdle` is added. It is Spark 3.5+. 
-class TestListenerV2(StreamingQueryListener): +class TestListenerSparkV2(StreamingQueryListener): def onQueryStarted(self, event): e = pyspark.cloudpickle.dumps(event) df = self.spark.createDataFrame(data=[(e,)]) @@ -65,8 +66,140 @@ class TestListenerV2(StreamingQueryListener): df.write.mode("append").saveAsTable("listener_terminated_events_v2") +class TestListenerLocal(StreamingQueryListener): + def __init__(self): + self.start = [] + self.progress = [] + self.terminated = [] + + def onQueryStarted(self, event): + self.start.append(event) + + def onQueryProgress(self, event): + self.progress.append(event) + + def onQueryIdle(self, event): + pass + + def onQueryTerminated(self, event): + self.terminated.append(event) + + class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): - def test_listener_events(self): + def test_listener_management(self): + listener1 = TestListenerLocal() + listener2 = TestListenerLocal() + + try: + self.spark.streams.addListener(listener1) + self.spark.streams.addListener(listener2) + q = self.spark.readStream.format("rate").load().writeStream.format("noop").start() + + # Both listeners should have listener events already because onQueryStarted + # is always called before DataStreamWriter.start() returns + self.assertEqual(len(listener1.start), 1) + self.assertEqual(len(listener2.start), 1) + + # removeListener is a blocking call, resources are cleaned up by the time it returns + self.spark.streams.removeListener(listener1) + self.spark.streams.removeListener(listener2) + + # Add back the listener and stop the query, now should see a terminated event + self.spark.streams.addListener(listener1) + q.stop() + + # need to wait a while before QueryTerminatedEvent reaches client + time.sleep(15) + self.assertEqual(len(listener1.terminated), 1) + + self.check_start_event(listener1.start[0]) + for event in listener1.progress: + self.check_progress_event(event) + self.check_terminated_event(listener1.terminated[0]) + + finally: + for listener in self.spark.streams._sqlb._listener_bus: + self.spark.streams.removeListener(listener) + for q in self.spark.streams.active: + q.stop() + + def test_slow_query(self): + try: + listener = TestListenerLocal() + self.spark.streams.addListener(listener) + + slow_query = ( + self.spark.readStream.format("rate") + .load() + .writeStream.format("noop") + .trigger(processingTime="20 seconds") + .start() + ) + fast_query = ( + self.spark.readStream.format("rate").load().writeStream.format("noop").start() + ) + + while slow_query.lastProgress is None: + slow_query.awaitTermination(20) + + slow_query.stop() + fast_query.stop() + + self.assertTrue(slow_query.id in [str(e.id) for e in listener.start]) + self.assertTrue(fast_query.id in [str(e.id) for e in listener.start]) + + self.assertTrue(slow_query.id in [str(e.progress.id) for e in listener.progress]) + self.assertTrue(fast_query.id in [str(e.progress.id) for e in listener.progress]) + + self.assertTrue(slow_query.id in [str(e.id) for e in listener.terminated]) + self.assertTrue(fast_query.id in [str(e.id) for e in listener.terminated]) + + finally: + for listener in self.spark.streams._sqlb._listener_bus: + self.spark.streams.removeListener(listener) + for q in self.spark.streams.active: + q.stop() + + def test_listener_throw(self): + """ + Following Vanilla Spark's behavior, when the callback of user-defined listener throws, + other listeners should still proceed. 
+ """ + + class UselessListener(StreamingQueryListener): + def onQueryStarted(self, e): + raise Exception("My bad!") + + def onQueryProgress(self, e): + raise Exception("My bad again!") + + def onQueryTerminated(self, e): + raise Exception("I'm so sorry!") + + try: + listener_good = TestListenerLocal() + listener_bad = UselessListener() + self.spark.streams.addListener(listener_good) + self.spark.streams.addListener(listener_bad) + + q = self.spark.readStream.format("rate").load().writeStream.format("noop").start() + + while q.lastProgress is None: + q.awaitTermination(0.5) + + q.stop() + # need to wait a while before QueryTerminatedEvent reaches client + time.sleep(5) + self.assertTrue(len(listener_good.start) > 0) + self.assertTrue(len(listener_good.progress) > 0) + self.assertTrue(len(listener_good.terminated) > 0) + finally: + for listener in self.spark.streams._sqlb._listener_bus: + self.spark.streams.removeListener(listener) + for q in self.spark.streams.active: + q.stop() + + def test_listener_events_spark_command(self): def verify(test_listener, table_postfix): try: self.spark.streams.addListener(test_listener) @@ -88,7 +221,7 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes self.assertTrue(q.isActive) # ensure at least one batch is ran while q.lastProgress is None or q.lastProgress["batchId"] == 0: - time.sleep(5) + q.awaitTermination(5) q.stop() self.assertFalse(q.isActive) @@ -129,46 +262,8 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes "listener_progress_events_v2", "listener_terminated_events_v2", ): - verify(TestListenerV1(), "_v1") - verify(TestListenerV2(), "_v2") - - def test_accessing_spark_session(self): - spark = self.spark - - class TestListener(StreamingQueryListener): - def onQueryStarted(self, event): - spark.createDataFrame( - [("you", "can"), ("serialize", "spark")] - ).createOrReplaceTempView("test_accessing_spark_session") - - def onQueryProgress(self, event): - pass - - def onQueryIdle(self, event): - pass - - def onQueryTerminated(self, event): - pass - - self.spark.streams.addListener(TestListener()) - - def test_accessing_spark_session_through_df(self): - dataframe = self.spark.createDataFrame([("you", "can"), ("serialize", "dataframe")]) - - class TestListener(StreamingQueryListener): - def onQueryStarted(self, event): - dataframe.collect() - - def onQueryProgress(self, event): - pass - - def onQueryIdle(self, event): - pass - - def onQueryTerminated(self, event): - pass - - self.spark.streams.addListener(TestListener()) + verify(TestListenerSparkV1(), "_v1") + verify(TestListenerSparkV2(), "_v2") if __name__ == "__main__": --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org