This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a83eec8235cc [SPARK-55691][CONNECT] GetStatus client
a83eec8235cc is described below
commit a83eec8235cc9241687b483524927f545dc84596
Author: Anastasiia Terenteva <[email protected]>
AuthorDate: Fri Feb 27 10:57:22 2026 -0400
[SPARK-55691][CONNECT] GetStatus client
### What changes were proposed in this pull request?
Client-side implementation of the GetStatus API:
- Add client methods for requesting operation statuses.
- GetStatus API is designed to support request types other than operation
statuses in the future. For those types, new methods will have to be created. This
is a tradeoff between convenient client methods, which encapsulate most of
the request building, and flexibility.
- Mark the client methods as Experimental and keep them underscored (private)
for now, until enough confidence in the API's stability is built.
- Add client tests with mocked service.
- Add E2E tests with real operation lifecycles.
### Why are the changes needed?
GetStatus API allows clients to monitor the status of executions in a session,
which is particularly useful in multithreaded clients.
### Does this PR introduce _any_ user-facing change?
Yes. It's a new Spark Connect API.
### How was this patch tested?
- Add client tests with mocked service.
- Add E2E tests with real operation lifecycles, which also test server-side
changes.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: Claude 4.6 Opus High
Closes #54520 from terana/get-status-client.
Authored-by: Anastasiia Terenteva <[email protected]>
Signed-off-by: Herman van Hövell <[email protected]>
---
python/pyspark/sql/connect/client/core.py | 53 +++++
.../sql/tests/connect/client/test_client.py | 174 ++++++++++++++++-
.../connect/client/SparkConnectClientSuite.scala | 164 ++++++++++++++++
.../client/CustomSparkConnectBlockingStub.scala | 13 ++
.../sql/connect/client/SparkConnectClient.scala | 39 +++-
.../spark/sql/connect/SparkConnectServerTest.scala | 33 ++--
.../connect/service/GetStatusHandlerE2ESuite.scala | 216 +++++++++++++++++++++
7 files changed, 673 insertions(+), 19 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index f28d97e7472e..aa060df24e41 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1930,6 +1930,58 @@ class SparkConnectClient(object):
except Exception as error:
self._handle_error(error)
+ def _get_operation_statuses(
+ self,
+ operation_ids: Optional[List[str]] = None,
+ operation_extensions: Optional[List[any_pb2.Any]] = None,
+ request_extensions: Optional[List[any_pb2.Any]] = None,
+ ) -> "pb2.GetStatusResponse":
+ """
+ Get status of operations in the session.
+
+ Parameters
+ ----------
+ operation_ids : list of str, optional
+ List of operation IDs to get status for.
+ If None or empty, returns status of all operations in the session.
+ operation_extensions : list of google.protobuf.any_pb2.Any, optional
+ Per-operation extension messages to include in the
OperationStatusRequest to request
+ additional per-operation information.
+ request_extensions : list of google.protobuf.any_pb2.Any, optional
+ Request-level extension messages to include in the
GetStatusRequest.
+
+ Returns
+ -------
+ pb2.GetStatusResponse
+ The full GetStatusResponse, including operation_statuses and any
extensions.
+ """
+ req = pb2.GetStatusRequest()
+ req.session_id = self._session_id
+ req.client_type = self._builder.userAgent
+ if self._user_id:
+ req.user_context.user_id = self._user_id
+ if self._server_session_id:
+ req.client_observed_server_side_session_id =
self._server_session_id
+
+ req.operation_status.SetInParent()
+
+ if operation_ids:
+ req.operation_status.operation_ids.extend(operation_ids)
+ if operation_extensions:
+ req.operation_status.extensions.extend(operation_extensions)
+ if request_extensions:
+ req.extensions.extend(request_extensions)
+
+ try:
+ for attempt in self._retrying():
+ with attempt:
+ resp = self._stub.GetStatus(req,
metadata=self._builder.metadata())
+ self._verify_response_integrity(resp)
+ return resp
+ raise SparkConnectException("Invalid state during retry exception
handling.")
+ except Exception as error:
+ self._handle_error(error)
+
def add_tag(self, tag: str) -> None:
self._throw_if_invalid_tag(tag)
if not hasattr(self.thread_local, "tags"):
@@ -2169,6 +2221,7 @@ class SparkConnectClient(object):
pb2.AnalyzePlanResponse,
pb2.FetchErrorDetailsResponse,
pb2.ReleaseSessionResponse,
+ pb2.GetStatusResponse,
],
) -> None:
"""
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py
b/python/pyspark/sql/tests/connect/client/test_client.py
index 8daf82ad3f37..55faff5e9ed3 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -134,10 +134,25 @@ if should_test_connect:
req: Optional[proto.ExecutePlanRequest]
- def __init__(self, session_id: str):
+ OperationStatus = proto.GetStatusResponse.OperationStatus
+ DEFAULT_OPERATION_STATUSES = [
+ OperationStatus(
+ operation_id="default-op-1",
+ state=OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
+ ),
+ OperationStatus(
+ operation_id="default-op-2",
+ state=OperationStatus.OperationState.OPERATION_STATE_RUNNING,
+ ),
+ ]
+
+ def __init__(self, session_id: str, operation_statuses=None):
self._session_id = session_id
self.req = None
self.client_user_context_extensions = []
+ if operation_statuses is None:
+ operation_statuses = self.DEFAULT_OPERATION_STATUSES
+ self._operation_statuses = {s.operation_id: s for s in
operation_statuses}
def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
self.req = req
@@ -191,6 +206,45 @@ if should_test_connect:
resp.semantic_hash.result = 12345
return resp
+ def GetStatus(self, req: proto.GetStatusRequest, metadata):
+ self.req = req
+ self.client_user_context_extensions =
list(req.user_context.extensions)
+ self.received_custom_server_session_id =
req.client_observed_server_side_session_id
+ resp = proto.GetStatusResponse(session_id=self._session_id)
+
+ # Echo top-level request extensions back in the response
+ if req.extensions:
+ resp.extensions.extend(req.extensions)
+
+ if not req.HasField("operation_status"):
+ return resp
+
+ # Collect operation-status-level extensions from the request to
echo back
+ op_status_extensions = list(req.operation_status.extensions)
+
+ requested_ids = list(req.operation_status.operation_ids)
+ if len(requested_ids) == 0:
+ # Empty list — return all statuses
+
resp.operation_statuses.extend(self._operation_statuses.values())
+ return resp
+
+ OperationStatus = proto.GetStatusResponse.OperationStatus
+ for op_id in requested_ids:
+ status = self._operation_statuses.get(op_id)
+ if status is not None:
+ op_status = OperationStatus(
+ operation_id=status.operation_id,
+ state=status.state,
+ )
+ else:
+ op_status = OperationStatus(
+ operation_id=op_id,
+
state=OperationStatus.OperationState.OPERATION_STATE_UNKNOWN,
+ )
+ op_status.extensions.extend(op_status_extensions)
+ resp.operation_statuses.append(op_status)
+ return resp
+
# The _cleanup_ml_cache invocation will hang in this test (no valid spark
cluster)
# and it blocks the test process exiting because it is registered as the
atexit handler
# in `SparkConnectClient` constructor. To bypass the issue, patch the
method in the test.
@@ -396,6 +450,124 @@ class SparkConnectClientTestCase(unittest.TestCase):
for resp in client._stub.ExecutePlan(req, metadata=None):
assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
+ def test_get_operations_statuses_all(self):
+ """Test get_operations_statuses returns all operation statuses when no
IDs specified."""
+ OperationStatus = proto.GetStatusResponse.OperationStatus
+ client = SparkConnectClient("sc://foo/;token=bar",
use_reattachable_execute=False)
+ mock = MockService(client._session_id)
+ client._stub = mock
+
+ resp = client._get_operation_statuses()
+ result = list(resp.operation_statuses)
+ self.assertEqual(len(result), 2)
+ status_map = {s.operation_id: s.state for s in result}
+ self.assertEqual(
+ status_map["default-op-1"],
+ OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
+ )
+ self.assertEqual(
+ status_map["default-op-2"],
+ OperationStatus.OperationState.OPERATION_STATE_RUNNING,
+ )
+
+ def test_get_operations_statuses_specific_ids(self):
+ """Test get_operations_statuses filters by specific operation IDs."""
+ OperationStatus = proto.GetStatusResponse.OperationStatus
+ client = SparkConnectClient("sc://foo/;token=bar",
use_reattachable_execute=False)
+ mock = MockService(client._session_id)
+ client._stub = mock
+
+ resp = client._get_operation_statuses(operation_ids=["default-op-1",
"unknown-op"])
+ result = list(resp.operation_statuses)
+ self.assertEqual(len(result), 2)
+ status_map = {s.operation_id: s.state for s in result}
+ self.assertEqual(
+ status_map["default-op-1"],
+ OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
+ )
+ self.assertEqual(
+ status_map["unknown-op"],
+ OperationStatus.OperationState.OPERATION_STATE_UNKNOWN,
+ )
+ # Verify the request included the operation IDs
+ self.assertEqual(
+ set(mock.req.operation_status.operation_ids), {"default-op-1",
"unknown-op"}
+ )
+
+ def test_get_operations_statuses_empty(self):
+ """Test get_operations_statuses returns empty list when no operations
exist."""
+ client = SparkConnectClient("sc://foo/;token=bar",
use_reattachable_execute=False)
+ mock = MockService(client._session_id, operation_statuses=[])
+ client._stub = mock
+
+ resp = client._get_operation_statuses()
+ self.assertEqual(len(list(resp.operation_statuses)), 0)
+
+ def test_get_operations_statuses_with_operation_extensions(self):
+ """Test get_operations_statuses passes operation-level extensions and
echoes them back per operation."""
+ from google.protobuf import any_pb2, wrappers_pb2
+
+ client = SparkConnectClient("sc://foo/;token=bar",
use_reattachable_execute=False)
+ mock = MockService(client._session_id)
+ client._stub = mock
+
+ op_ext = any_pb2.Any()
+ op_ext.Pack(wrappers_pb2.StringValue(value="op_extension"))
+
+ resp = client._get_operation_statuses(
+ operation_ids=["default-op-1", "default-op-2"],
+ operation_extensions=[op_ext],
+ )
+ result = list(resp.operation_statuses)
+ self.assertEqual(len(result), 2)
+ self.assertEqual({s.operation_id for s in result}, {"default-op-1",
"default-op-2"})
+
+ # Verify operation-level extensions were included in the request
+ self.assertEqual(len(mock.req.operation_status.extensions), 1)
+ unpacked = wrappers_pb2.StringValue()
+ mock.req.operation_status.extensions[0].Unpack(unpacked)
+ self.assertEqual(unpacked.value, "op_extension")
+
+ # Verify operation-level extensions were echoed back per operation
+ for op_status in result:
+ self.assertEqual(len(op_status.extensions), 1)
+ echoed = wrappers_pb2.StringValue()
+ op_status.extensions[0].Unpack(echoed)
+ self.assertEqual(echoed.value, "op_extension")
+
+ def test_get_operations_statuses_with_request_extensions(self):
+ """Test _get_operation_statuses sends request-level extensions and
echoes them back."""
+ from google.protobuf import any_pb2, wrappers_pb2
+
+ client = SparkConnectClient("sc://foo/;token=bar",
use_reattachable_execute=False)
+ mock = MockService(client._session_id)
+ client._stub = mock
+
+ req_ext = any_pb2.Any()
+ req_ext.Pack(wrappers_pb2.StringValue(value="request_extension"))
+
+ resp = client._get_operation_statuses(
+ operation_ids=["default-op-1"],
+ request_extensions=[req_ext],
+ )
+
+ # Verify the operation status is returned
+ result = list(resp.operation_statuses)
+ self.assertEqual(len(result), 1)
+ self.assertEqual(result[0].operation_id, "default-op-1")
+
+ # Verify request-level extensions were included in the request
+ self.assertEqual(len(mock.req.extensions), 1)
+ unpacked = wrappers_pb2.StringValue()
+ mock.req.extensions[0].Unpack(unpacked)
+ self.assertEqual(unpacked.value, "request_extension")
+
+ # Verify request-level extensions were echoed back in the response
+ self.assertEqual(len(resp.extensions), 1)
+ resp_echoed = wrappers_pb2.StringValue()
+ resp.extensions[0].Unpack(resp_echoed)
+ self.assertEqual(resp_echoed.value, "request_extension")
+
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectClientReattachTestCase(unittest.TestCase):
diff --git
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 1cfc2d3eb09f..e833657984dc 100644
---
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
import scala.collection.mutable
import scala.jdk.CollectionConverters._
+import com.google.protobuf.{Any => PAny, StringValue}
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, Metadata,
MethodDescriptor, Server, ServerCall, ServerCallHandler, ServerInterceptor,
Status, StatusRuntimeException}
import io.grpc.netty.NettyServerBuilder
import io.grpc.stub.StreamObserver
@@ -548,6 +549,114 @@ class SparkConnectClientSuite extends ConnectFunSuite {
observer.onCompleted()
}
+ test("getOperationStatuses returns operation statuses for requested IDs") {
+ startDummyServer(0)
+ client = SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:${server.getPort}")
+ .build()
+
+ val response =
+ client.getOperationStatuses(Seq("default-op-1", "unknown-op"))
+ val statuses = response.getOperationStatusesList.asScala.toSeq
+ assert(statuses.size == 2)
+ assert(statuses.map(_.getOperationId).toSet == Set("default-op-1",
"unknown-op"))
+
+ val statusMap = statuses.map(s => s.getOperationId -> s.getState).toMap
+ assert(
+ statusMap("default-op-1") ==
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED)
+ assert(
+ statusMap("unknown-op") ==
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_UNKNOWN)
+ }
+
+ test("getOperationStatuses with no IDs returns all operations from server") {
+ startDummyServer(0)
+ client = SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:${server.getPort}")
+ .build()
+
+ val response = client.getOperationStatuses()
+ val statuses = response.getOperationStatusesList.asScala.toSeq
+ assert(statuses.size == 2)
+ assert(statuses.map(_.getOperationId).toSet == Set("default-op-1",
"default-op-2"))
+
+ val statusMap = statuses.map(s => s.getOperationId -> s.getState).toMap
+ assert(
+ statusMap("default-op-1") ==
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED)
+ assert(
+ statusMap("default-op-2") ==
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING)
+ }
+
+ test("getOperationStatuses sends extensions and returns them per operation")
{
+ startDummyServer(0)
+ client = SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:${server.getPort}")
+ .build()
+
+ val extension = PAny.pack(StringValue.of("custom_extension"))
+
+ val response = client.getOperationStatuses(
+ operationIds = Seq("default-op-1", "default-op-2"),
+ operationExtensions = Seq(extension))
+
+ // Verify operation statuses are returned
+ val statuses = response.getOperationStatusesList.asScala.toSeq
+ assert(statuses.size == 2)
+
+ val statusMap = statuses.map(s => s.getOperationId -> s).toMap
+ assert(
+ statusMap("default-op-1").getState ==
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED)
+ assert(
+ statusMap("default-op-2").getState ==
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING)
+
+ // Verify that extensions are echoed back per operation
+ statuses.foreach { status =>
+ val opExtensions = status.getExtensionsList.asScala.toSeq
+ assert(opExtensions.size == 1)
+ assert(opExtensions.head.is(classOf[StringValue]))
+ assert(
+ opExtensions.head
+ .unpack(classOf[StringValue])
+ .getValue == "custom_extension")
+ }
+ }
+
+ test("getOperationStatuses sends request-level extensions and echoes them in
the response") {
+ startDummyServer(0)
+ client = SparkConnectClient
+ .builder()
+ .connectionString(s"sc://localhost:${server.getPort}")
+ .build()
+
+ val reqExtension = PAny.pack(StringValue.of("request_extension"))
+
+ val response = client.getOperationStatuses(
+ operationIds = Seq("default-op-1"),
+ requestExtensions = Seq(reqExtension))
+
+ // Verify the operation status is returned
+ val statuses = response.getOperationStatusesList.asScala.toSeq
+ assert(statuses.size == 1)
+ assert(statuses.head.getOperationId == "default-op-1")
+
+ // Verify request-level extensions were echoed back in the response
+ val responseExtensions = response.getExtensionsList.asScala.toSeq
+ assert(responseExtensions.size == 1)
+ assert(responseExtensions.head.is(classOf[StringValue]))
+ assert(
+ responseExtensions.head
+ .unpack(classOf[StringValue])
+ .getValue == "request_extension")
+ }
+
test("client can set a custom operation id for ExecutePlan requests") {
startDummyServer(0)
client = SparkConnectClient
@@ -927,4 +1036,59 @@ class DummySparkConnectService() extends
SparkConnectServiceGrpc.SparkConnectSer
responseObserver.onNext(response)
responseObserver.onCompleted()
}
+
+ // Default operations stored in the mock session
+ private val defaultOperationStatuses = Map(
+ "default-op-1" ->
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
+ "default-op-2" ->
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING)
+
+ override def getStatus(
+ request: proto.GetStatusRequest,
+ responseObserver: StreamObserver[proto.GetStatusResponse]): Unit = {
+ val responseBuilder = proto.GetStatusResponse
+ .newBuilder()
+ .setSessionId(request.getSessionId)
+ .setServerSideSessionId(UUID.randomUUID().toString)
+
+ if (request.hasOperationStatus) {
+ val opStatusRequest = request.getOperationStatus
+ val requestedIds = opStatusRequest.getOperationIdsList.asScala.toSeq
+ // Collect operation-status-level extensions from the request to echo
back
+ val opStatusExtensions = opStatusRequest.getExtensionsList.asScala
+
+ if (requestedIds.isEmpty) {
+ // No specific IDs requested - return all default operations
+ defaultOperationStatuses.foreach { case (opId, state) =>
+ val statusBuilder = proto.GetStatusResponse.OperationStatus
+ .newBuilder()
+ .setOperationId(opId)
+ .setState(state)
+ opStatusExtensions.foreach(statusBuilder.addExtensions)
+ responseBuilder.addOperationStatuses(statusBuilder.build())
+ }
+ } else {
+ // Return status for each requested operation ID
+ // Unknown operations return OPERATION_STATE_UNKNOWN
+ requestedIds.foreach { opId =>
+ val state = defaultOperationStatuses.getOrElse(
+ opId,
+
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_UNKNOWN)
+ val statusBuilder = proto.GetStatusResponse.OperationStatus
+ .newBuilder()
+ .setOperationId(opId)
+ .setState(state)
+ opStatusExtensions.foreach(statusBuilder.addExtensions)
+ responseBuilder.addOperationStatuses(statusBuilder.build())
+ }
+ }
+ }
+
+ // Echo request-level extensions back in the response
+ request.getExtensionsList.asScala.foreach(responseBuilder.addExtensions)
+
+ responseObserver.onNext(responseBuilder.build())
+ responseObserver.onCompleted()
+ }
}
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 715da0df7349..ad39a3fb29f2 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -145,4 +145,17 @@ private[connect] class CustomSparkConnectBlockingStub(
}
}
}
+
+ def getStatus(request: GetStatusRequest): GetStatusResponse = {
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
+ retryHandler.retry {
+ stubState.responseValidator.verifyResponse {
+ stub.getStatus(request)
+ }
+ }
+ }
+ }
}
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 0fa7d9ada48b..d9b9ba35b5e6 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -33,7 +33,7 @@ import io.grpc._
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
import org.apache.spark.SparkThrowable
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.UserContext
import org.apache.spark.internal.Logging
@@ -496,6 +496,43 @@ private[sql] class SparkConnectClient(
bstub.releaseSession(request.build())
}
+ /**
+ * Get status of operations in the session.
+ *
+ * @param operationIds
+ * Optional sequence of operation IDs to get status for. If empty, returns
status of all
+ * operations in the session.
+ * @param operationExtensions
+ * Optional per-operation extensions to include in the
OperationStatusRequest.
+ * @param requestExtensions
+ * Optional request-level extensions to include in the GetStatusRequest.
+ * @return
+ * The [[proto.GetStatusResponse]] for the requested operations, including
any extensions.
+ */
+ @Experimental
+ def getOperationStatuses(
+ operationIds: Seq[String] = Seq.empty,
+ operationExtensions: Seq[protobuf.Any] = Seq.empty,
+ requestExtensions: Seq[protobuf.Any] = Seq.empty):
proto.GetStatusResponse = {
+ val requestBuilder = proto.GetStatusRequest
+ .newBuilder()
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .setClientType(userAgent)
+
+ serverSideSessionId.foreach(session =>
+ requestBuilder.setClientObservedServerSideSessionId(session))
+
+ val opStatusRequest =
proto.GetStatusRequest.OperationStatusRequest.newBuilder()
+ operationIds.foreach(opStatusRequest.addOperationIds)
+ operationExtensions.foreach(opStatusRequest.addExtensions)
+ requestBuilder.setOperationStatus(opStatusRequest)
+
+ requestExtensions.foreach(requestBuilder.addExtensions)
+
+ bstub.getStatus(requestBuilder.build())
+ }
+
private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
override def childValue(parent: mutable.Set[String]): mutable.Set[String]
= {
// Note: make a clone such that changes in the parent tags aren't
reflected in
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 77ede8e852e8..748e7623aba3 100644
---
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -196,12 +196,14 @@ trait SparkConnectServerTest extends SharedSparkSession {
}
}
- protected def assertEventuallyNoActiveRpcs(): Unit = {
+ protected def eventuallyWithTimeout[T](f: => T): T = {
Eventually.eventually(timeout(eventuallyTimeout)) {
- assertNoActiveRpcs()
+ f
}
}
+ protected def assertEventuallyNoActiveRpcs(): Unit =
eventuallyWithTimeout(assertNoActiveRpcs())
+
protected def assertNoActiveExecutions(): Unit = {
SparkConnectService.executionManager.listActiveExecutions match {
case Left(_) => // cleaned up
@@ -209,11 +211,8 @@ trait SparkConnectServerTest extends SharedSparkSession {
}
}
- protected def assertEventuallyNoActiveExecutions(): Unit = {
- Eventually.eventually(timeout(eventuallyTimeout)) {
- assertNoActiveExecutions()
- }
- }
+ protected def assertEventuallyNoActiveExecutions(): Unit =
+ eventuallyWithTimeout(assertNoActiveExecutions())
protected def assertExecutionReleased(operationId: String): Unit = {
SparkConnectService.executionManager.listActiveExecutions match {
@@ -222,11 +221,8 @@ trait SparkConnectServerTest extends SharedSparkSession {
}
}
- protected def assertEventuallyExecutionReleased(operationId: String): Unit =
{
- Eventually.eventually(timeout(eventuallyTimeout)) {
- assertExecutionReleased(operationId)
- }
- }
+ protected def assertEventuallyExecutionReleased(operationId: String): Unit =
+ eventuallyWithTimeout(assertExecutionReleased(operationId))
// Get ExecutionHolder, assuming that only one execution is active
protected def getExecutionHolder: ExecuteHolder = {
@@ -235,11 +231,14 @@ trait SparkConnectServerTest extends SharedSparkSession {
executions.head
}
- protected def eventuallyGetExecutionHolder: ExecuteHolder = {
- Eventually.eventually(timeout(eventuallyTimeout)) {
- getExecutionHolder
- }
- }
+ protected def getExecutionHolderForOperation(opId: String): ExecuteHolder =
+
SparkConnectService.executionManager.listExecuteHolders.find(_.key.operationId
== opId).get
+
+ protected def eventuallyGetExecutionHolderForOperation(opId: String):
ExecuteHolder =
+ eventuallyWithTimeout(getExecutionHolderForOperation(opId))
+
+ protected def eventuallyGetExecutionHolder: ExecuteHolder =
+ eventuallyWithTimeout(getExecutionHolder)
protected def withClient(sessionId: String = defaultSessionId, userId:
String = defaultUserId)(
f: SparkConnectClient => Unit): Unit = {
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerE2ESuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerE2ESuite.scala
new file mode 100644
index 000000000000..17d16162e31f
--- /dev/null
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerE2ESuite.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import scala.jdk.CollectionConverters._
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.connect.proto
+import
org.apache.spark.connect.proto.GetStatusResponse.OperationStatus.OperationState
+import org.apache.spark.sql.connect.SparkConnectServerTest
+
+class GetStatusHandlerE2ESuite extends SparkConnectServerTest {
+
+ test("GetStatus tracks operation through lifecycle: RUNNING -> SUCCEEDED") {
+ withClient { client =>
+ val plan = buildPlan("SELECT java_method('java.lang.Thread', 'sleep',
2000L) as value")
+ val iter = client.execute(plan)
+ val operationId = iter.next().getOperationId
+
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ assert(SparkConnectService.executionManager.listExecuteHolders.length
== 1)
+ }
+
+ val runningStatuses =
+
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+ assert(runningStatuses.size == 1)
+ assert(runningStatuses.head.getState ==
OperationState.OPERATION_STATE_RUNNING)
+
+ while (iter.hasNext) iter.next()
+ assertEventuallyExecutionReleased(operationId)
+
+ // Use eventually to skip the intermediate TERMINATING status
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ val succeededStatuses =
+
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+ assert(succeededStatuses.size == 1)
+ assert(succeededStatuses.head.getState ==
OperationState.OPERATION_STATE_SUCCEEDED)
+ }
+ }
+ }
+
+ test("GetStatus tracks operation through lifecycle: RUNNING -> CANCELLED") {
+ withClient { client =>
+ val plan = buildPlan("SELECT java_method('java.lang.Thread', 'sleep',
3000L) as value")
+ val iter = client.execute(plan)
+ val operationId = iter.next().getOperationId
+
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ assert(SparkConnectService.executionManager.listExecuteHolders.length
== 1)
+ }
+
+ val runningStatuses =
+
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+ assert(runningStatuses.size == 1)
+ assert(runningStatuses.head.getState ==
OperationState.OPERATION_STATE_RUNNING)
+
+ client.interruptOperation(operationId)
+
+ intercept[SparkException] {
+ while (iter.hasNext) iter.next()
+ }
+
+ assertEventuallyExecutionReleased(operationId)
+
+ // Use eventually to skip the intermediate TERMINATING status
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ val cancelledStatuses =
+
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+ assert(cancelledStatuses.size == 1)
+ assert(cancelledStatuses.head.getState ==
OperationState.OPERATION_STATE_CANCELLED)
+ }
+ }
+ }
+
+ test("GetStatus returns FAILED for query with error") {
+ withClient { client =>
+ // Use assert_true with a dynamic condition to trigger a runtime failure
+ val plan = buildPlan("SELECT assert_true(id < 0) FROM range(1)")
+ val iter = client.execute(plan)
+ val operationId = iter.next().getOperationId
+
+ // Wait for the execution thread to finish before consuming the error
from
+ // the iterator. This prevents a race where consuming the error triggers
the
+ // client's ReleaseExecute, and removeExecuteHolder interrupts the
execution
+ // thread while it's still in its finally block, which would override the
+ // termination reason from Failed to Canceled.
+ val holder = eventuallyGetExecutionHolderForOperation(operationId)
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ assert(!holder.isExecuteThreadRunnerAlive())
+ }
+
+ intercept[SparkRuntimeException] {
+ while (iter.hasNext) iter.next()
+ }
+
+ assertEventuallyExecutionReleased(operationId)
+
+ // Use eventually to skip the intermediate TERMINATING status
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ val statuses =
+
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+ assert(statuses.size == 1)
+ assert(statuses.head.getOperationId == operationId)
+ assert(statuses.head.getState == OperationState.OPERATION_STATE_FAILED)
+ }
+ }
+ }
+
+ test("GetStatus returns UNKNOWN for non-existent operation") {
+ withClient { client =>
+ // Execute a simple query first to establish the session on the server
+ val plan = buildPlan("SELECT 1")
+ val iter = client.execute(plan)
+ while (iter.hasNext) iter.next()
+
+ // Query for a random operation ID that was never created
+ val nonExistentOperationId = UUID.randomUUID().toString
+ val statuses =
+
client.getOperationStatuses(Seq(nonExistentOperationId)).getOperationStatusesList.asScala
+ assert(statuses.size == 1)
+ assert(statuses.head.getOperationId == nonExistentOperationId)
+ assert(statuses.head.getState == OperationState.OPERATION_STATE_UNKNOWN)
+ }
+ }
+
+ test("GetStatus returns all operation statuses when no IDs specified") {
+ withClient { client =>
+ val plan1 = buildPlan("SELECT 1 as first_value")
+ val iter1 = client.execute(plan1)
+ val operationId1 = iter1.next().getOperationId
+ while (iter1.hasNext) iter1.next()
+ assertEventuallyExecutionReleased(operationId1)
+
+ val plan2 = buildPlan("SELECT assert_true(id < 0) FROM range(1)")
+ val iter2 = client.execute(plan2)
+ val operationId2 = iter2.next().getOperationId
+
+ // Wait for the execution thread to finish before consuming the error.
+ // Same race condition as in the FAILED test above.
+ val holder2 = eventuallyGetExecutionHolderForOperation(operationId2)
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ assert(!holder2.isExecuteThreadRunnerAlive())
+ }
+
+ intercept[SparkRuntimeException] {
+ while (iter2.hasNext) iter2.next()
+ }
+
+ assertEventuallyExecutionReleased(operationId2)
+
+ // Use eventually to skip the intermediate TERMINATING status
+ Eventually.eventually(timeout(eventuallyTimeout)) {
+ val statuses =
client.getOperationStatuses().getOperationStatusesList.asScala
+ val operationIdToStatus = statuses.map(s => s.getOperationId ->
s.getState).toMap
+
+ assert(statuses.size == 2)
+ assert(operationIdToStatus(operationId1) ==
OperationState.OPERATION_STATE_SUCCEEDED)
+ assert(operationIdToStatus(operationId2) ==
OperationState.OPERATION_STATE_FAILED)
+ }
+ }
+ }
+
+ test("GetStatus returns session info even when no operation status is
requested") {
+ // This test needs raw stub to send a request without operation_status
field
+ withRawBlockingStub { stub =>
+ // Execute a simple query to establish the session on the server.
+ val plan = buildPlan("SELECT 1")
+ val executeRequest = buildExecutePlanRequest(plan)
+ val sessionId = executeRequest.getSessionId
+ val operationId = executeRequest.getOperationId
+
+ val iter = stub.executePlan(executeRequest)
+ while (iter.hasNext) iter.next()
+
+ val releaseRequest = proto.ReleaseExecuteRequest
+ .newBuilder()
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .setOperationId(operationId)
+
.setReleaseAll(proto.ReleaseExecuteRequest.ReleaseAll.newBuilder().build())
+ .build()
+ stub.releaseExecute(releaseRequest)
+ assertEventuallyExecutionReleased(operationId)
+
+ val statusRequest = proto.GetStatusRequest
+ .newBuilder()
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .build()
+
+ val response = stub.getStatus(statusRequest)
+
+ assert(response.getSessionId == sessionId)
+ assert(response.getServerSideSessionId.nonEmpty)
+ assert(response.getOperationStatusesList.isEmpty)
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]