This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f086c2327c36 [SPARK-47174][CONNECT][SS][1/2] Server side SparkConnectListenerBusListener for Client side streaming query listener
f086c2327c36 is described below
commit f086c2327c36c396ae5d886afd3ef613650c6b0d
Author: Wei Liu <[email protected]>
AuthorDate: Fri Apr 12 10:08:45 2024 +0900
    [SPARK-47174][CONNECT][SS][1/2] Server side SparkConnectListenerBusListener for Client side streaming query listener
### What changes were proposed in this pull request?
Server side `SparkConnectListenerBusListener` implementation for the client side listener. There is only one such listener per `SessionHolder`.
### Why are the changes needed?
Moves the streaming query listener handling to the client side.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit test
### Was this patch authored or co-authored using generative AI tooling?
No
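
For context, the client-facing API this server side piece ultimately serves is the standard StreamingQueryListener registration (the client side half lands in the follow-up [2/2] PR). A sketch of that usage, assuming an already-built session named spark; this is illustrative, not part of this patch:

    import org.apache.spark.sql.streaming.StreamingQueryListener

    // A standard listener; with this change, Connect streams these events from
    // the server's SparkConnectListenerBusListener back to client side listeners.
    val listener = new StreamingQueryListener {
      override def onQueryStarted(e: StreamingQueryListener.QueryStartedEvent): Unit =
        println(s"started: ${e.id}")
      override def onQueryProgress(e: StreamingQueryListener.QueryProgressEvent): Unit =
        println(s"progress: ${e.progress.batchId}")
      override def onQueryTerminated(e: StreamingQueryListener.QueryTerminatedEvent): Unit =
        println(s"terminated: ${e.id}")
    }

    spark.streams.addListener(listener)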
Closes #45988 from WweiL/SPARK-47174-client-side-listener-1.
Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 34 ++-
...SparkConnectStreamingQueryListenerHandler.scala | 121 +++++++++++
.../spark/sql/connect/service/SessionHolder.scala | 7 +
.../service/SparkConnectListenerBusListener.scala | 156 ++++++++++++++
.../SparkConnectListenerBusListenerSuite.scala | 240 +++++++++++++++++++++
5 files changed, 555 insertions(+), 3 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 96db45c5c63e..5e7f3b74c299 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2551,6 +2551,11 @@ class SparkConnectPlanner(
handleStreamingQueryManagerCommand(
command.getStreamingQueryManagerCommand,
responseObserver)
+      case proto.Command.CommandTypeCase.STREAMING_QUERY_LISTENER_BUS_COMMAND =>
+        val handler = new SparkConnectStreamingQueryListenerHandler(executeHolder)
+        handler.handleListenerCommand(
+          command.getStreamingQueryListenerBusCommand,
+          responseObserver)
case proto.Command.CommandTypeCase.GET_RESOURCES_COMMAND =>
handleGetResourcesCommand(responseObserver)
case proto.Command.CommandTypeCase.CREATE_RESOURCE_PROFILE_COMMAND =>
@@ -3118,7 +3123,7 @@ class SparkConnectPlanner(
}
executeHolder.eventsManager.postFinished()
- val result = WriteStreamOperationStartResult
+ val resultBuilder = WriteStreamOperationStartResult
.newBuilder()
.setQueryId(
StreamingQueryInstanceId
@@ -3127,14 +3132,37 @@ class SparkConnectPlanner(
.setRunId(query.runId.toString)
.build())
.setName(Option(query.name).getOrElse(""))
- .build()
+
+    // The query started event for this query is sent to the client, and is handled by
+    // the client side listeners before the client's DataStreamWriter.start() returns.
+    // This is to ensure that the onQueryStarted callback is called before the start() call,
+    // as defined in the onQueryStarted API.
+    // So the flow is:
+    // 1. On the server side, the query is started above.
+    // 2. Per the contract of the onQueryStarted API, the queryStartedEvent is added to the
+    //    streamingServersideListenerHolder.streamingQueryStartedEventCache, by the onQueryStarted
+    //    callback of streamingServersideListenerHolder.streamingQueryServerSideListener.
+    // 3. The queryStartedEvent is sent to the client.
+    // 4. The client side listener handles the queryStartedEvent and calls the onQueryStarted API,
+    //    before the client side DataStreamWriter.start() returns.
+    // This way we ensure that the onQueryStarted API is called before the start() call in Connect.
+    val queryStartedEvent = Option(
+      sessionHolder.streamingServersideListenerHolder.streamingQueryStartedEventCache.remove(
+        query.runId.toString))
+    queryStartedEvent.foreach { e =>
+      logDebug(
+        s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+          s"${executeHolder.operationId}][query id: ${query.id}][query runId: ${query.runId}] " +
+          s"Adding QueryStartedEvent to response")
+      resultBuilder.setQueryStartedEventJson(e.json)
+    }
responseObserver.onNext(
ExecutePlanResponse
.newBuilder()
.setSessionId(sessionId)
.setServerSideSessionId(sessionHolder.serverSessionId)
- .setWriteStreamOperationStartResult(result)
+ .setWriteStreamOperationStartResult(resultBuilder.build())
.build())
}
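
A note for reviewers on the cache-and-attach pattern above: the QueryStartedEvent is produced asynchronously on the listener bus thread, parked in a concurrent map keyed by runId, and picked up by the thread answering WriteStreamOperationStart. A minimal standalone sketch of that handoff, with hypothetical names (startedEvents, onStarted, startQueryAndReply) that are not part of this patch:

    import java.util.UUID
    import java.util.concurrent.ConcurrentHashMap

    object CacheAndAttachSketch {
      // Hypothetical stand-in for streamingQueryStartedEventCache: events parked
      // by the listener thread, keyed by the query's runId.
      private val startedEvents = new ConcurrentHashMap[String, String]()

      // Stands in for onQueryStarted: called on the listener bus thread.
      def onStarted(runId: UUID, eventJson: String): Unit = {
        startedEvents.put(runId.toString, eventJson)
      }

      // Stands in for the WriteStreamOperationStart handler: remove the parked
      // event (if any) and attach it to the reply, so the client can invoke its
      // onQueryStarted callbacks before DataStreamWriter.start() returns.
      def startQueryAndReply(runId: UUID): Option[String] = {
        Option(startedEvents.remove(runId.toString))
      }
    }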
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala
new file mode 100644
index 000000000000..94f01026b7a5
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.planner
+
+import scala.util.control.NonFatal
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.connect.proto.StreamingQueryListenerBusCommand
+import org.apache.spark.connect.proto.StreamingQueryListenerEventsResult
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.service.ExecuteHolder
+
+/**
+ * Handles long-running streaming query listener events.
+ */
+class SparkConnectStreamingQueryListenerHandler(executeHolder: ExecuteHolder) extends Logging {
+
+ val sessionHolder = executeHolder.sessionHolder
+
+ private[connect] def userId: String = sessionHolder.userId
+
+ private[connect] def sessionId: String = sessionHolder.sessionId
+
+  /**
+   * The handler logic. The handler of ADD_LISTENER_BUS_LISTENER uses the
+   * streamingQueryListenerLatch to block the handling thread, preventing it from sending back
+   * the final ResultComplete response.
+   *
+   * The handler of REMOVE_LISTENER_BUS_LISTENER cleans up the server side listener resources
+   * and counts down the latch, allowing the handling thread of the original
+   * ADD_LISTENER_BUS_LISTENER command to proceed to send back the final ResultComplete response.
+   */
+ def handleListenerCommand(
+ command: StreamingQueryListenerBusCommand,
+ responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+
+ val listenerHolder = sessionHolder.streamingServersideListenerHolder
+
+    command.getCommandCase match {
+      case StreamingQueryListenerBusCommand.CommandCase.ADD_LISTENER_BUS_LISTENER =>
+        listenerHolder.isServerSideListenerRegistered match {
+          case true =>
+            logWarning(
+              s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+                s"${executeHolder.operationId}] Redundant server side listener added. Exiting.")
+            return
+          case false =>
+            // This defers sending back the response to the client until
+            // the long-running command is terminated, either by
+            // an error in streamingQueryServerSideListener.send,
+            // or by the client issuing a REMOVE_LISTENER_BUS_LISTENER call.
+            listenerHolder.init(responseObserver)
+            // Send back the listener-added response.
+            val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
+            val listenerAddedResult = respBuilder
+              .setListenerBusListenerAdded(true)
+              .build()
+            try {
+              responseObserver.onNext(
+                ExecutePlanResponse
+                  .newBuilder()
+                  .setSessionId(sessionHolder.sessionId)
+                  .setServerSideSessionId(sessionHolder.serverSessionId)
+                  .setStreamingQueryListenerEventsResult(listenerAddedResult)
+                  .build())
+            } catch {
+              case NonFatal(e) =>
+                logError(
+                  s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+                    s"${executeHolder.operationId}] Error sending listener added response.",
+                  e)
+                listenerHolder.cleanUp()
+                return
+            }
+        }
+        logInfo(s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+          s"${executeHolder.operationId}] Server side listener added. Now blocking until " +
+          "all client side listeners are removed or there is an error transmitting events back.")
+        // Block the handling thread, and have serverListener continuously send back new events.
+        listenerHolder.streamingQueryListenerLatch.await()
+        logInfo(s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+          s"${executeHolder.operationId}] Server side listener long-running handling thread ended.")
+      case StreamingQueryListenerBusCommand.CommandCase.REMOVE_LISTENER_BUS_LISTENER =>
+        listenerHolder.isServerSideListenerRegistered match {
+          case true =>
+            sessionHolder.streamingServersideListenerHolder.cleanUp()
+          case false =>
+            logWarning(
+              s"[SessionId: $sessionId][UserId: $userId][operationId: " +
+                s"${executeHolder.operationId}] No active server side listener bus listener " +
+                "but received a remove listener call. Exiting.")
+            return
+        }
+      case StreamingQueryListenerBusCommand.CommandCase.COMMAND_NOT_SET =>
+        throw new IllegalArgumentException("Missing command in StreamingQueryListenerBusCommand")
+    }
+    // If this thread is the handling thread of the original ADD_LISTENER_BUS_LISTENER command,
+    // this is reached when the latch is counted down (either through
+    // a REMOVE_LISTENER_BUS_LISTENER command, or when the long-lived gRPC call throws).
+    // If this thread is the handling thread of the REMOVE_LISTENER_BUS_LISTENER command,
+    // this is hit right away.
+    executeHolder.eventsManager.postFinished()
+ }
+}
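
The latch choreography in handleListenerCommand is easier to see in isolation. Below is a self-contained sketch of the same pattern, assuming one latch per registration; handleAdd and handleRemove are hypothetical names standing in for the two command handlers:

    import java.util.concurrent.CountDownLatch

    object LatchSketch {
      @volatile private var latch = new CountDownLatch(1)

      // Stands in for the ADD_LISTENER_BUS_LISTENER handler: register, then park
      // this thread so the final ResultComplete response is not sent yet.
      def handleAdd(): Unit = {
        latch = new CountDownLatch(1)
        println("listener added; blocking the handling thread")
        latch.await() // released by handleRemove() or by a failed send
        println("latch released; handling thread may now send ResultComplete")
      }

      // Stands in for REMOVE_LISTENER_BUS_LISTENER (or the cleanup path after a
      // failed send): release the parked ADD thread.
      def handleRemove(): Unit = latch.countDown()

      def main(args: Array[String]): Unit = {
        val adder = new Thread(() => handleAdd())
        adder.start()
        Thread.sleep(100) // demo only: give the ADD handler time to park
        handleRemove()
        adder.join()
      }
    }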
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index ef79cdcce8ff..306b89148583 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -92,6 +92,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
new StreamingForeachBatchHelper.CleanerCache(this)
+  private[connect] lazy val streamingServersideListenerHolder = new ServerSideListenerHolder(this)
+
def key: SessionKey = SessionKey(userId, sessionId)
  // Returns the server side session ID and asserts that it must be different from the client-side
@@ -267,6 +269,11 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers.
    removeAllListeners() // removes all listeners and stops python listener processes if necessary.
+ // if there is a server side listener, clean up related resources
+ if (streamingServersideListenerHolder.isServerSideListenerRegistered) {
+ streamingServersideListenerHolder.cleanUp()
+ }
+
// Clean up all executions.
// After closedTimeMs is defined, SessionHolder.addExecuteHolder() will
not allow new executions
// to be added for this session anymore. Because both
SessionHolder.addExecuteHolder() and
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
new file mode 100644
index 000000000000..1b6c5179871d
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch}
+
+import scala.jdk.CollectionConverters._
+import scala.util.control.NonFatal
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto.ExecutePlanResponse
+import org.apache.spark.connect.proto.StreamingQueryEventType
+import org.apache.spark.connect.proto.StreamingQueryListenerEvent
+import org.apache.spark.connect.proto.StreamingQueryListenerEventsResult
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.streaming.StreamingQueryListener
+import org.apache.spark.util.ArrayImplicits._
+
+/**
+ * A holder for the server side listener and related resources. There should be only one such
+ * holder for each sessionHolder.
+ */
+private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
+  // The server side listener that streams streaming query events back to the client.
+  // There is only one listener per sessionHolder, but each listener is responsible for all
+  // events of all streaming queries in the SparkSession.
+  var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None
+  // The count down latch to hold the long-running listener thread before sending ResultComplete.
+  var streamingQueryListenerLatch = new CountDownLatch(1)
+  // The cache for QueryStartedEvent, keyed by query runId; the value is the actual
+  // QueryStartedEvent. The event for the corresponding query is sent back to the client with
+  // the WriteStreamOperationStart response, so that the client can handle the event before
+  // DataStreamWriter.start() returns. This special handling is to satisfy the contract of
+  // onQueryStarted in StreamingQueryListener.
+  val streamingQueryStartedEventCache
+    : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap()
+
+  def isServerSideListenerRegistered: Boolean = streamingQueryServerSideListener.isDefined
+
+  /**
+   * The initialization of the server side listener and related resources. This method is called
+   * when the first ADD_LISTENER_BUS_LISTENER command is received. It is attached to a
+   * responseObserver, from the first executeThread (long-running thread), so the lifecycle of
+   * the responseObserver is the same as the lifecycle of the listener.
+   *
+   * @param responseObserver
+   *   the responseObserver created from the first long-running executeThread.
+   */
+ def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+    val serverListener = new SparkConnectListenerBusListener(this, responseObserver)
+ sessionHolder.session.streams.addListener(serverListener)
+ streamingQueryServerSideListener = Some(serverListener)
+ streamingQueryListenerLatch = new CountDownLatch(1)
+ }
+
+  /**
+   * The cleanup of the server side listener and related resources. This method is called when
+   * the REMOVE_LISTENER_BUS_LISTENER command is received or when responseObserver.onNext throws
+   * an exception. It removes the listener from the session and clears the cache. It also counts
+   * down the latch, so the long-running thread can proceed to send back the final ResultComplete
+   * response.
+   */
+ def cleanUp(): Unit = {
+ streamingQueryServerSideListener.foreach { listener =>
+ sessionHolder.session.streams.removeListener(listener)
+ }
+ streamingQueryStartedEventCache.clear()
+ streamingQueryServerSideListener = None
+ streamingQueryListenerLatch.countDown()
+ }
+}
+
+/**
+ * A customized StreamingQueryListener used in Spark Connect for the client-side listeners. Upon
+ * the invocation of each callback, it serializes the event to JSON and sends it to the client.
+ */
+private[sql] class SparkConnectListenerBusListener(
+ serverSideListenerHolder: ServerSideListenerHolder,
+ responseObserver: StreamObserver[ExecutePlanResponse])
+ extends StreamingQueryListener
+ with Logging {
+
+ val sessionHolder = serverSideListenerHolder.sessionHolder
+  // The method used to stream events back to the client. The event is serialized to JSON
+  // and sent to the client. The responseObserver is the one from the first executeThread
+  // (long-running thread), which is held open by the streamingQueryListenerLatch.
+  // If any exception is thrown while transmitting an event back, the listener is removed,
+  // all related resources are cleaned up, and the long-running thread proceeds to send
+  // the final ResultComplete response.
+  private def send(eventJson: String, eventType: StreamingQueryEventType): Unit = {
+ val event = StreamingQueryListenerEvent
+ .newBuilder()
+ .setEventJson(eventJson)
+ .setEventType(eventType)
+ .build()
+
+ val respBuilder = StreamingQueryListenerEventsResult.newBuilder()
+    val eventResult = respBuilder
+      .addAllEvents(Array[StreamingQueryListenerEvent](event).toImmutableArraySeq.asJava)
+      .build()
+
+ try {
+ responseObserver.onNext(
+ ExecutePlanResponse
+ .newBuilder()
+ .setSessionId(sessionHolder.sessionId)
+ .setServerSideSessionId(sessionHolder.serverSessionId)
+ .setStreamingQueryListenerEventsResult(eventResult)
+ .build())
+ } catch {
+ case NonFatal(e) =>
+        logError(
+          s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " +
+            s"Removing SparkConnectListenerBusListener and terminating the long-running thread " +
+            s"because of exception: $e")
+        // This likely means that the client is not responsive even with retries; we should
+        // remove this listener and clean up resources.
+        serverSideListenerHolder.cleanUp()
+ }
+ }
+
+  // QueryStartedEvent is sent to the client along with WriteStreamOperationStartResult.
+  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
+    serverSideListenerHolder.streamingQueryStartedEventCache.put(event.runId.toString, event)
+  }
+
+  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
+    send(event.json, StreamingQueryEventType.QUERY_PROGRESS_EVENT)
+  }
+
+  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
+    send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT)
+  }
+
+  override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
+    send(event.json, StreamingQueryEventType.QUERY_IDLE_EVENT)
+  }
+}
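
The error handling in send reflects a common pattern for long-lived gRPC streams: any failure of onNext is treated as the client having gone away, and all server side state is torn down. A minimal sketch under that assumption, using a hypothetical Sink trait in place of the gRPC StreamObserver:

    import scala.util.control.NonFatal

    // Hypothetical stand-in for io.grpc.stub.StreamObserver[ExecutePlanResponse].
    trait Sink {
      def onNext(payload: String): Unit
    }

    final class ResilientSender(sink: Sink, cleanUp: () => Unit) {
      // Push one serialized event; on any non-fatal failure, run the cleanup
      // callback (remove the listener, clear caches, count down the latch),
      // mirroring what SparkConnectListenerBusListener.send does in this patch.
      def send(eventJson: String): Unit = {
        try {
          sink.onNext(eventJson)
        } catch {
          case NonFatal(e) =>
            println(s"client unresponsive, cleaning up: $e")
            cleanUp()
        }
      }
    }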
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala
new file mode 100644
index 000000000000..4c2962fda507
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration.DurationInt
+import scala.jdk.CollectionConverters._
+
+import io.grpc.stub.StreamObserver
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.connect.proto.{Command, ExecutePlanResponse}
+import org.apache.spark.sql.connect.planner.SparkConnectStreamingQueryListenerHandler
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener}
+import org.apache.spark.sql.streaming.Trigger.ProcessingTime
+import org.apache.spark.sql.test.SharedSparkSession
+
+class SparkConnectListenerBusListenerSuite
+ extends SparkFunSuite
+ with SharedSparkSession
+ with MockitoSugar {
+
+ override def afterEach(): Unit = {
+ try {
+ spark.streams.active.foreach(_.stop())
+ spark.streams.listListeners().foreach(spark.streams.removeListener)
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ // A test listener that caches all events
+ private class CacheEventsStreamingQueryListener(
+ startEvents: ArrayBuffer[StreamingQueryListener.QueryStartedEvent],
+ otherEvents: ArrayBuffer[StreamingQueryListener.Event])
+ extends StreamingQueryListener {
+
+    override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
+      startEvents += event
+    }
+
+    override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
+      otherEvents += event
+    }
+
+    override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
+      otherEvents += event
+    }
+
+    override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
+      otherEvents += event
+    }
+ }
+
+ private def verifyEventsSent(
+ fromCachedEventsListener: ArrayBuffer[StreamingQueryListener.Event],
+ fromListenerBusListener: ArrayBuffer[String]): Unit = {
+ assert(fromListenerBusListener.toSet === fromCachedEventsListener.map {
+ case e: StreamingQueryListener.QueryStartedEvent => e.json
+ case e: StreamingQueryListener.QueryProgressEvent => e.json
+ case e: StreamingQueryListener.QueryTerminatedEvent => e.json
+ case e: StreamingQueryListener.QueryIdleEvent => e.json
+ }.toSet)
+ }
+
+ private def startQuery(slow: Boolean = false): StreamingQuery = {
+ val dsw = spark.readStream.format("rate").load().writeStream.format("noop")
+ if (slow) {
+ dsw.trigger(ProcessingTime("20 seconds"))
+ }
+ dsw.start()
+ }
+
+  Seq(1, 5, 20).foreach { queryNum =>
+    test(
+      "Basic functionalities - onQueryStart, onQueryProgress, onQueryTerminated" +
+        s" - $queryNum queries") {
+      val sessionHolder = SessionHolder.forTesting(spark)
+      val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
+      val eventJsonBuffer = ArrayBuffer.empty[String]
+      val startEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.QueryStartedEvent]
+      val otherEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.Event]
+
+      doAnswer((invocation: InvocationOnMock) => {
+        val argument = invocation.getArgument[ExecutePlanResponse](0)
+        val eventJson = argument.getStreamingQueryListenerEventsResult().getEvents(0).getEventJson
+        eventJsonBuffer += eventJson
+      }).when(responseObserver).onNext(any[ExecutePlanResponse]())
+
+      val listenerHolder = sessionHolder.streamingServersideListenerHolder
+      listenerHolder.init(responseObserver)
+      val cachedEventsListener =
+        new CacheEventsStreamingQueryListener(startEventsBuffer, otherEventsBuffer)
+
+      spark.streams.addListener(cachedEventsListener)
+
+      for (_ <- 1 to queryNum) startQuery()
+
+      // after all queries have made some progress
+      eventually(timeout(60.seconds), interval(2.seconds)) {
+        spark.streams.active.foreach { q =>
+          assert(q.lastProgress.batchId > 5)
+        }
+      }
+
+      // stop all queries
+      spark.streams.active.foreach(_.stop())
+
+      eventually(timeout(60.seconds), interval(500.milliseconds)) {
+        assert(eventJsonBuffer.nonEmpty)
+        assert(!listenerHolder.streamingQueryStartedEventCache.isEmpty)
+        verifyEventsSent(otherEventsBuffer, eventJsonBuffer)
+        assert(
+          startEventsBuffer.map(_.json).toSet ===
+            listenerHolder.streamingQueryStartedEventCache.asScala.map(_._2.json).toSet)
+      }
+    }
+  }
+
+ test("Basic functionalities - Slow query") {
+ val sessionHolder = SessionHolder.forTesting(spark)
+ val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
+ val eventJsonBuffer = ArrayBuffer.empty[String]
+ val startEventsBuffer =
ArrayBuffer.empty[StreamingQueryListener.QueryStartedEvent]
+ val otherEventsBuffer = ArrayBuffer.empty[StreamingQueryListener.Event]
+
+ doAnswer((invocation: InvocationOnMock) => {
+ val argument = invocation.getArgument[ExecutePlanResponse](0)
+ val eventJson =
argument.getStreamingQueryListenerEventsResult().getEvents(0).getEventJson
+ eventJsonBuffer += eventJson
+ }).when(responseObserver).onNext(any[ExecutePlanResponse]())
+
+ val listenerHolder = sessionHolder.streamingServersideListenerHolder
+ listenerHolder.init(responseObserver)
+
+ val cachedEventsListener =
+ new CacheEventsStreamingQueryListener(startEventsBuffer,
otherEventsBuffer)
+ spark.streams.addListener(cachedEventsListener)
+
+ // Slow query
+ val q = startQuery(true)
+
+ // after the slow query made some progresses
+ eventually(timeout(100.seconds), interval(7.seconds)) {
+ assert(q.lastProgress.batchId > 2)
+ }
+
+ q.stop()
+
+ eventually(timeout(60.seconds), interval(1.second)) {
+ assert(eventJsonBuffer.nonEmpty)
+ assert(!listenerHolder.streamingQueryStartedEventCache.isEmpty)
+ verifyEventsSent(otherEventsBuffer, eventJsonBuffer)
+ assert(
+ startEventsBuffer.map(_.json).toSet ===
+
listenerHolder.streamingQueryStartedEventCache.asScala.map(_._2.json).toSet)
+ }
+ }
+
+ test("Proper handling on onNext throw - initial response") {
+ val sessionHolder = SessionHolder.forTesting(spark)
+
+ val executeHolder = mock[ExecuteHolder]
+ when(executeHolder.sessionHolder).thenReturn(sessionHolder)
+ when(executeHolder.operationId).thenReturn("operationId")
+
+ val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
+ doThrow(new RuntimeException("I'm dead"))
+ .when(responseObserver)
+ .onNext(any[ExecutePlanResponse]())
+
+ val listenerCntBeforeThrow = spark.streams.listListeners().size
+
+ val handler = new SparkConnectStreamingQueryListenerHandler(executeHolder)
+    val listenerBusCmdBuilder = Command.newBuilder().getStreamingQueryListenerBusCommandBuilder
+    val addListenerCommand = listenerBusCmdBuilder.setAddListenerBusListener(true).build()
+    handler.handleListenerCommand(addListenerCommand, responseObserver)
+
+    val listenerHolder = sessionHolder.streamingServersideListenerHolder
+    eventually(timeout(5.seconds), interval(500.milliseconds)) {
+      assert(
+        sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.isEmpty)
+      assert(spark.streams.listListeners().size === listenerCntBeforeThrow)
+      assert(listenerHolder.streamingQueryStartedEventCache.isEmpty)
+      assert(listenerHolder.streamingQueryListenerLatch.getCount === 0)
+    }
+  }
+
+ test("Proper handling on onNext throw - query progress") {
+ val sessionHolder = SessionHolder.forTesting(spark)
+ val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
+ doThrow(new RuntimeException("I'm dead"))
+ .when(responseObserver)
+ .onNext(any[ExecutePlanResponse]())
+
+ val listenerHolder = sessionHolder.streamingServersideListenerHolder
+ listenerHolder.init(responseObserver)
+    val listenerBusListener = listenerHolder.streamingQueryServerSideListener.get
+
+ // mock a QueryStartedEvent for cleanup test
+ val queryStartedEvent = new StreamingQueryListener.QueryStartedEvent(
+ UUID.randomUUID,
+ UUID.randomUUID,
+ "name",
+ "timestamp")
+ listenerHolder.streamingQueryStartedEventCache.put(
+ queryStartedEvent.runId.toString,
+ queryStartedEvent)
+
+ startQuery()
+
+ eventually(timeout(5.seconds), interval(500.milliseconds)) {
+ assert(!spark.streams.listListeners().contains(listenerBusListener))
+ assert(listenerHolder.streamingQueryStartedEventCache.isEmpty)
+ assert(listenerHolder.streamingQueryListenerLatch.getCount === 0)
+ }
+ }
+}