This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a83eec8235cc [SPARK-55691][CONNECT] GetStatus client
a83eec8235cc is described below

commit a83eec8235cc9241687b483524927f545dc84596
Author: Anastasiia Terenteva <[email protected]>
AuthorDate: Fri Feb 27 10:57:22 2026 -0400

    [SPARK-55691][CONNECT] GetStatus client
    
    ### What changes were proposed in this pull request?
    
    Client-side implementation of the GetStatus API:
    - Add client methods for requesting operation statuses.
      - GetStatus API is designed to support other request types than operation 
statuses in the future. For those types, new methods will have to be created. This 
is a tradeoff between convenient client methods, which encapsulate most of 
the request building, and flexibility.
    - Mark the client methods as Experimental and underscored for now, until 
enough confidence in the API stability is built.
    - Add client tests with mocked service.
    - Add E2E tests with real operation lifecycles.
    
    ### Why are the changes needed?
    
    GetStatus API allows clients to monitor the status of executions in a session, which is 
particularly useful in multithreaded clients.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. It's a new Spark Connect API.
    
    ### How was this patch tested?
    
    - Add client tests with mocked service.
    - Add E2E tests with real operation lifecycles, which also test server-side 
changes.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Claude 4.6 Opus High
    
    Closes #54520 from terana/get-status-client.
    
    Authored-by: Anastasiia Terenteva <[email protected]>
    Signed-off-by: Herman van Hövell <[email protected]>
---
 python/pyspark/sql/connect/client/core.py          |  53 +++++
 .../sql/tests/connect/client/test_client.py        | 174 ++++++++++++++++-
 .../connect/client/SparkConnectClientSuite.scala   | 164 ++++++++++++++++
 .../client/CustomSparkConnectBlockingStub.scala    |  13 ++
 .../sql/connect/client/SparkConnectClient.scala    |  39 +++-
 .../spark/sql/connect/SparkConnectServerTest.scala |  33 ++--
 .../connect/service/GetStatusHandlerE2ESuite.scala | 216 +++++++++++++++++++++
 7 files changed, 673 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index f28d97e7472e..aa060df24e41 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1930,6 +1930,58 @@ class SparkConnectClient(object):
         except Exception as error:
             self._handle_error(error)
 
+    def _get_operation_statuses(
+        self,
+        operation_ids: Optional[List[str]] = None,
+        operation_extensions: Optional[List[any_pb2.Any]] = None,
+        request_extensions: Optional[List[any_pb2.Any]] = None,
+    ) -> "pb2.GetStatusResponse":
+        """
+        Get status of operations in the session.
+
+        Parameters
+        ----------
+        operation_ids : list of str, optional
+            List of operation IDs to get status for.
+            If None or empty, returns status of all operations in the session.
+        operation_extensions : list of google.protobuf.any_pb2.Any, optional
+            Per-operation extension messages to include in the 
OperationStatusRequest to request
+            additional per-operation information.
+        request_extensions : list of google.protobuf.any_pb2.Any, optional
+            Request-level extension messages to include in the 
GetStatusRequest.
+
+        Returns
+        -------
+        pb2.GetStatusResponse
+            The full GetStatusResponse, including operation_statuses and any 
extensions.
+        """
+        req = pb2.GetStatusRequest()
+        req.session_id = self._session_id
+        req.client_type = self._builder.userAgent
+        if self._user_id:
+            req.user_context.user_id = self._user_id
+        if self._server_session_id:
+            req.client_observed_server_side_session_id = 
self._server_session_id
+
+        req.operation_status.SetInParent()
+
+        if operation_ids:
+            req.operation_status.operation_ids.extend(operation_ids)
+        if operation_extensions:
+            req.operation_status.extensions.extend(operation_extensions)
+        if request_extensions:
+            req.extensions.extend(request_extensions)
+
+        try:
+            for attempt in self._retrying():
+                with attempt:
+                    resp = self._stub.GetStatus(req, 
metadata=self._builder.metadata())
+                    self._verify_response_integrity(resp)
+                    return resp
+            raise SparkConnectException("Invalid state during retry exception 
handling.")
+        except Exception as error:
+            self._handle_error(error)
+
     def add_tag(self, tag: str) -> None:
         self._throw_if_invalid_tag(tag)
         if not hasattr(self.thread_local, "tags"):
@@ -2169,6 +2221,7 @@ class SparkConnectClient(object):
             pb2.AnalyzePlanResponse,
             pb2.FetchErrorDetailsResponse,
             pb2.ReleaseSessionResponse,
+            pb2.GetStatusResponse,
         ],
     ) -> None:
         """
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py 
b/python/pyspark/sql/tests/connect/client/test_client.py
index 8daf82ad3f37..55faff5e9ed3 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -134,10 +134,25 @@ if should_test_connect:
 
         req: Optional[proto.ExecutePlanRequest]
 
-        def __init__(self, session_id: str):
+        OperationStatus = proto.GetStatusResponse.OperationStatus
+        DEFAULT_OPERATION_STATUSES = [
+            OperationStatus(
+                operation_id="default-op-1",
+                state=OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
+            ),
+            OperationStatus(
+                operation_id="default-op-2",
+                state=OperationStatus.OperationState.OPERATION_STATE_RUNNING,
+            ),
+        ]
+
+        def __init__(self, session_id: str, operation_statuses=None):
             self._session_id = session_id
             self.req = None
             self.client_user_context_extensions = []
+            if operation_statuses is None:
+                operation_statuses = self.DEFAULT_OPERATION_STATUSES
+            self._operation_statuses = {s.operation_id: s for s in 
operation_statuses}
 
         def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
             self.req = req
@@ -191,6 +206,45 @@ if should_test_connect:
             resp.semantic_hash.result = 12345
             return resp
 
+        def GetStatus(self, req: proto.GetStatusRequest, metadata):
+            self.req = req
+            self.client_user_context_extensions = 
list(req.user_context.extensions)
+            self.received_custom_server_session_id = 
req.client_observed_server_side_session_id
+            resp = proto.GetStatusResponse(session_id=self._session_id)
+
+            # Echo top-level request extensions back in the response
+            if req.extensions:
+                resp.extensions.extend(req.extensions)
+
+            if not req.HasField("operation_status"):
+                return resp
+
+            # Collect operation-status-level extensions from the request to 
echo back
+            op_status_extensions = list(req.operation_status.extensions)
+
+            requested_ids = list(req.operation_status.operation_ids)
+            if len(requested_ids) == 0:
+                # Empty list — return all statuses
+                
resp.operation_statuses.extend(self._operation_statuses.values())
+                return resp
+
+            OperationStatus = proto.GetStatusResponse.OperationStatus
+            for op_id in requested_ids:
+                status = self._operation_statuses.get(op_id)
+                if status is not None:
+                    op_status = OperationStatus(
+                        operation_id=status.operation_id,
+                        state=status.state,
+                    )
+                else:
+                    op_status = OperationStatus(
+                        operation_id=op_id,
+                        
state=OperationStatus.OperationState.OPERATION_STATE_UNKNOWN,
+                    )
+                op_status.extensions.extend(op_status_extensions)
+                resp.operation_statuses.append(op_status)
+            return resp
+
     # The _cleanup_ml_cache invocation will hang in this test (no valid spark 
cluster)
     # and it blocks the test process exiting because it is registered as the 
atexit handler
     # in `SparkConnectClient` constructor. To bypass the issue, patch the 
method in the test.
@@ -396,6 +450,124 @@ class SparkConnectClientTestCase(unittest.TestCase):
         for resp in client._stub.ExecutePlan(req, metadata=None):
             assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
 
+    def test_get_operations_statuses_all(self):
+        """Test get_operations_statuses returns all operation statuses when no 
IDs specified."""
+        OperationStatus = proto.GetStatusResponse.OperationStatus
+        client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
+        mock = MockService(client._session_id)
+        client._stub = mock
+
+        resp = client._get_operation_statuses()
+        result = list(resp.operation_statuses)
+        self.assertEqual(len(result), 2)
+        status_map = {s.operation_id: s.state for s in result}
+        self.assertEqual(
+            status_map["default-op-1"],
+            OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
+        )
+        self.assertEqual(
+            status_map["default-op-2"],
+            OperationStatus.OperationState.OPERATION_STATE_RUNNING,
+        )
+
+    def test_get_operations_statuses_specific_ids(self):
+        """Test get_operations_statuses filters by specific operation IDs."""
+        OperationStatus = proto.GetStatusResponse.OperationStatus
+        client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
+        mock = MockService(client._session_id)
+        client._stub = mock
+
+        resp = client._get_operation_statuses(operation_ids=["default-op-1", 
"unknown-op"])
+        result = list(resp.operation_statuses)
+        self.assertEqual(len(result), 2)
+        status_map = {s.operation_id: s.state for s in result}
+        self.assertEqual(
+            status_map["default-op-1"],
+            OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
+        )
+        self.assertEqual(
+            status_map["unknown-op"],
+            OperationStatus.OperationState.OPERATION_STATE_UNKNOWN,
+        )
+        # Verify the request included the operation IDs
+        self.assertEqual(
+            set(mock.req.operation_status.operation_ids), {"default-op-1", 
"unknown-op"}
+        )
+
+    def test_get_operations_statuses_empty(self):
+        """Test get_operations_statuses returns empty list when no operations 
exist."""
+        client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
+        mock = MockService(client._session_id, operation_statuses=[])
+        client._stub = mock
+
+        resp = client._get_operation_statuses()
+        self.assertEqual(len(list(resp.operation_statuses)), 0)
+
+    def test_get_operations_statuses_with_operation_extensions(self):
+        """Test get_operations_statuses passes operation-level extensions and 
echoes them back per operation."""
+        from google.protobuf import any_pb2, wrappers_pb2
+
+        client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
+        mock = MockService(client._session_id)
+        client._stub = mock
+
+        op_ext = any_pb2.Any()
+        op_ext.Pack(wrappers_pb2.StringValue(value="op_extension"))
+
+        resp = client._get_operation_statuses(
+            operation_ids=["default-op-1", "default-op-2"],
+            operation_extensions=[op_ext],
+        )
+        result = list(resp.operation_statuses)
+        self.assertEqual(len(result), 2)
+        self.assertEqual({s.operation_id for s in result}, {"default-op-1", 
"default-op-2"})
+
+        # Verify operation-level extensions were included in the request
+        self.assertEqual(len(mock.req.operation_status.extensions), 1)
+        unpacked = wrappers_pb2.StringValue()
+        mock.req.operation_status.extensions[0].Unpack(unpacked)
+        self.assertEqual(unpacked.value, "op_extension")
+
+        # Verify operation-level extensions were echoed back per operation
+        for op_status in result:
+            self.assertEqual(len(op_status.extensions), 1)
+            echoed = wrappers_pb2.StringValue()
+            op_status.extensions[0].Unpack(echoed)
+            self.assertEqual(echoed.value, "op_extension")
+
+    def test_get_operations_statuses_with_request_extensions(self):
+        """Test _get_operation_statuses sends request-level extensions and 
echoes them back."""
+        from google.protobuf import any_pb2, wrappers_pb2
+
+        client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
+        mock = MockService(client._session_id)
+        client._stub = mock
+
+        req_ext = any_pb2.Any()
+        req_ext.Pack(wrappers_pb2.StringValue(value="request_extension"))
+
+        resp = client._get_operation_statuses(
+            operation_ids=["default-op-1"],
+            request_extensions=[req_ext],
+        )
+
+        # Verify the operation status is returned
+        result = list(resp.operation_statuses)
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0].operation_id, "default-op-1")
+
+        # Verify request-level extensions were included in the request
+        self.assertEqual(len(mock.req.extensions), 1)
+        unpacked = wrappers_pb2.StringValue()
+        mock.req.extensions[0].Unpack(unpacked)
+        self.assertEqual(unpacked.value, "request_extension")
+
+        # Verify request-level extensions were echoed back in the response
+        self.assertEqual(len(resp.extensions), 1)
+        resp_echoed = wrappers_pb2.StringValue()
+        resp.extensions[0].Unpack(resp_echoed)
+        self.assertEqual(resp_echoed.value, "request_extension")
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SparkConnectClientReattachTestCase(unittest.TestCase):
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 1cfc2d3eb09f..e833657984dc 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
+import com.google.protobuf.{Any => PAny, StringValue}
 import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, Metadata, 
MethodDescriptor, Server, ServerCall, ServerCallHandler, ServerInterceptor, 
Status, StatusRuntimeException}
 import io.grpc.netty.NettyServerBuilder
 import io.grpc.stub.StreamObserver
@@ -548,6 +549,114 @@ class SparkConnectClientSuite extends ConnectFunSuite {
     observer.onCompleted()
   }
 
+  test("getOperationStatuses returns operation statuses for requested IDs") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .build()
+
+    val response =
+      client.getOperationStatuses(Seq("default-op-1", "unknown-op"))
+    val statuses = response.getOperationStatusesList.asScala.toSeq
+    assert(statuses.size == 2)
+    assert(statuses.map(_.getOperationId).toSet == Set("default-op-1", 
"unknown-op"))
+
+    val statusMap = statuses.map(s => s.getOperationId -> s.getState).toMap
+    assert(
+      statusMap("default-op-1") ==
+        
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED)
+    assert(
+      statusMap("unknown-op") ==
+        
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_UNKNOWN)
+  }
+
+  test("getOperationStatuses with no IDs returns all operations from server") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .build()
+
+    val response = client.getOperationStatuses()
+    val statuses = response.getOperationStatusesList.asScala.toSeq
+    assert(statuses.size == 2)
+    assert(statuses.map(_.getOperationId).toSet == Set("default-op-1", 
"default-op-2"))
+
+    val statusMap = statuses.map(s => s.getOperationId -> s.getState).toMap
+    assert(
+      statusMap("default-op-1") ==
+        
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED)
+    assert(
+      statusMap("default-op-2") ==
+        
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING)
+  }
+
+  test("getOperationStatuses sends extensions and returns them per operation") 
{
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .build()
+
+    val extension = PAny.pack(StringValue.of("custom_extension"))
+
+    val response = client.getOperationStatuses(
+      operationIds = Seq("default-op-1", "default-op-2"),
+      operationExtensions = Seq(extension))
+
+    // Verify operation statuses are returned
+    val statuses = response.getOperationStatusesList.asScala.toSeq
+    assert(statuses.size == 2)
+
+    val statusMap = statuses.map(s => s.getOperationId -> s).toMap
+    assert(
+      statusMap("default-op-1").getState ==
+        
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED)
+    assert(
+      statusMap("default-op-2").getState ==
+        
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING)
+
+    // Verify that extensions are echoed back per operation
+    statuses.foreach { status =>
+      val opExtensions = status.getExtensionsList.asScala.toSeq
+      assert(opExtensions.size == 1)
+      assert(opExtensions.head.is(classOf[StringValue]))
+      assert(
+        opExtensions.head
+          .unpack(classOf[StringValue])
+          .getValue == "custom_extension")
+    }
+  }
+
+  test("getOperationStatuses sends request-level extensions and echoes them in 
the response") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .build()
+
+    val reqExtension = PAny.pack(StringValue.of("request_extension"))
+
+    val response = client.getOperationStatuses(
+      operationIds = Seq("default-op-1"),
+      requestExtensions = Seq(reqExtension))
+
+    // Verify the operation status is returned
+    val statuses = response.getOperationStatusesList.asScala.toSeq
+    assert(statuses.size == 1)
+    assert(statuses.head.getOperationId == "default-op-1")
+
+    // Verify request-level extensions were echoed back in the response
+    val responseExtensions = response.getExtensionsList.asScala.toSeq
+    assert(responseExtensions.size == 1)
+    assert(responseExtensions.head.is(classOf[StringValue]))
+    assert(
+      responseExtensions.head
+        .unpack(classOf[StringValue])
+        .getValue == "request_extension")
+  }
+
   test("client can set a custom operation id for ExecutePlan requests") {
     startDummyServer(0)
     client = SparkConnectClient
@@ -927,4 +1036,59 @@ class DummySparkConnectService() extends 
SparkConnectServiceGrpc.SparkConnectSer
     responseObserver.onNext(response)
     responseObserver.onCompleted()
   }
+
+  // Default operations stored in the mock session
+  private val defaultOperationStatuses = Map(
+    "default-op-1" ->
+      
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_SUCCEEDED,
+    "default-op-2" ->
+      
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_RUNNING)
+
+  override def getStatus(
+      request: proto.GetStatusRequest,
+      responseObserver: StreamObserver[proto.GetStatusResponse]): Unit = {
+    val responseBuilder = proto.GetStatusResponse
+      .newBuilder()
+      .setSessionId(request.getSessionId)
+      .setServerSideSessionId(UUID.randomUUID().toString)
+
+    if (request.hasOperationStatus) {
+      val opStatusRequest = request.getOperationStatus
+      val requestedIds = opStatusRequest.getOperationIdsList.asScala.toSeq
+      // Collect operation-status-level extensions from the request to echo 
back
+      val opStatusExtensions = opStatusRequest.getExtensionsList.asScala
+
+      if (requestedIds.isEmpty) {
+        // No specific IDs requested - return all default operations
+        defaultOperationStatuses.foreach { case (opId, state) =>
+          val statusBuilder = proto.GetStatusResponse.OperationStatus
+            .newBuilder()
+            .setOperationId(opId)
+            .setState(state)
+          opStatusExtensions.foreach(statusBuilder.addExtensions)
+          responseBuilder.addOperationStatuses(statusBuilder.build())
+        }
+      } else {
+        // Return status for each requested operation ID
+        // Unknown operations return OPERATION_STATE_UNKNOWN
+        requestedIds.foreach { opId =>
+          val state = defaultOperationStatuses.getOrElse(
+            opId,
+            
proto.GetStatusResponse.OperationStatus.OperationState.OPERATION_STATE_UNKNOWN)
+          val statusBuilder = proto.GetStatusResponse.OperationStatus
+            .newBuilder()
+            .setOperationId(opId)
+            .setState(state)
+          opStatusExtensions.foreach(statusBuilder.addExtensions)
+          responseBuilder.addOperationStatuses(statusBuilder.build())
+        }
+      }
+    }
+
+    // Echo request-level extensions back in the response
+    request.getExtensionsList.asScala.foreach(responseBuilder.addExtensions)
+
+    responseObserver.onNext(responseBuilder.build())
+    responseObserver.onCompleted()
+  }
 }
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 715da0df7349..ad39a3fb29f2 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -145,4 +145,17 @@ private[connect] class CustomSparkConnectBlockingStub(
       }
     }
   }
+
+  def getStatus(request: GetStatusRequest): GetStatusResponse = {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
+      retryHandler.retry {
+        stubState.responseValidator.verifyResponse {
+          stub.getStatus(request)
+        }
+      }
+    }
+  }
 }
diff --git 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 0fa7d9ada48b..d9b9ba35b5e6 100644
--- 
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -33,7 +33,7 @@ import io.grpc._
 
 import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
 import org.apache.spark.SparkThrowable
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.UserContext
 import org.apache.spark.internal.Logging
@@ -496,6 +496,43 @@ private[sql] class SparkConnectClient(
     bstub.releaseSession(request.build())
   }
 
+  /**
+   * Get status of operations in the session.
+   *
+   * @param operationIds
+   *   Optional sequence of operation IDs to get status for. If empty, returns 
status of all
+   *   operations in the session.
+   * @param operationExtensions
+   *   Optional per-operation extensions to include in the 
OperationStatusRequest.
+   * @param requestExtensions
+   *   Optional request-level extensions to include in the GetStatusRequest.
+   * @return
+   *   The [[proto.GetStatusResponse]] for the requested operations, including 
any extensions.
+   */
+  @Experimental
+  def getOperationStatuses(
+      operationIds: Seq[String] = Seq.empty,
+      operationExtensions: Seq[protobuf.Any] = Seq.empty,
+      requestExtensions: Seq[protobuf.Any] = Seq.empty): 
proto.GetStatusResponse = {
+    val requestBuilder = proto.GetStatusRequest
+      .newBuilder()
+      .setUserContext(userContext)
+      .setSessionId(sessionId)
+      .setClientType(userAgent)
+
+    serverSideSessionId.foreach(session =>
+      requestBuilder.setClientObservedServerSideSessionId(session))
+
+    val opStatusRequest = 
proto.GetStatusRequest.OperationStatusRequest.newBuilder()
+    operationIds.foreach(opStatusRequest.addOperationIds)
+    operationExtensions.foreach(opStatusRequest.addExtensions)
+    requestBuilder.setOperationStatus(opStatusRequest)
+
+    requestExtensions.foreach(requestBuilder.addExtensions)
+
+    bstub.getStatus(requestBuilder.build())
+  }
+
   private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
     override def childValue(parent: mutable.Set[String]): mutable.Set[String] 
= {
       // Note: make a clone such that changes in the parent tags aren't 
reflected in
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 77ede8e852e8..748e7623aba3 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -196,12 +196,14 @@ trait SparkConnectServerTest extends SharedSparkSession {
     }
   }
 
-  protected def assertEventuallyNoActiveRpcs(): Unit = {
+  protected def eventuallyWithTimeout[T](f: => T): T = {
     Eventually.eventually(timeout(eventuallyTimeout)) {
-      assertNoActiveRpcs()
+      f
     }
   }
 
+  protected def assertEventuallyNoActiveRpcs(): Unit = 
eventuallyWithTimeout(assertNoActiveRpcs())
+
   protected def assertNoActiveExecutions(): Unit = {
     SparkConnectService.executionManager.listActiveExecutions match {
       case Left(_) => // cleaned up
@@ -209,11 +211,8 @@ trait SparkConnectServerTest extends SharedSparkSession {
     }
   }
 
-  protected def assertEventuallyNoActiveExecutions(): Unit = {
-    Eventually.eventually(timeout(eventuallyTimeout)) {
-      assertNoActiveExecutions()
-    }
-  }
+  protected def assertEventuallyNoActiveExecutions(): Unit =
+    eventuallyWithTimeout(assertNoActiveExecutions())
 
   protected def assertExecutionReleased(operationId: String): Unit = {
     SparkConnectService.executionManager.listActiveExecutions match {
@@ -222,11 +221,8 @@ trait SparkConnectServerTest extends SharedSparkSession {
     }
   }
 
-  protected def assertEventuallyExecutionReleased(operationId: String): Unit = 
{
-    Eventually.eventually(timeout(eventuallyTimeout)) {
-      assertExecutionReleased(operationId)
-    }
-  }
+  protected def assertEventuallyExecutionReleased(operationId: String): Unit =
+    eventuallyWithTimeout(assertExecutionReleased(operationId))
 
   // Get ExecutionHolder, assuming that only one execution is active
   protected def getExecutionHolder: ExecuteHolder = {
@@ -235,11 +231,14 @@ trait SparkConnectServerTest extends SharedSparkSession {
     executions.head
   }
 
-  protected def eventuallyGetExecutionHolder: ExecuteHolder = {
-    Eventually.eventually(timeout(eventuallyTimeout)) {
-      getExecutionHolder
-    }
-  }
+  protected def getExecutionHolderForOperation(opId: String): ExecuteHolder =
+    
SparkConnectService.executionManager.listExecuteHolders.find(_.key.operationId 
== opId).get
+
+  protected def eventuallyGetExecutionHolderForOperation(opId: String): 
ExecuteHolder =
+    eventuallyWithTimeout(getExecutionHolderForOperation(opId))
+
+  protected def eventuallyGetExecutionHolder: ExecuteHolder =
+    eventuallyWithTimeout(getExecutionHolder)
 
   protected def withClient(sessionId: String = defaultSessionId, userId: 
String = defaultUserId)(
       f: SparkConnectClient => Unit): Unit = {
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerE2ESuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerE2ESuite.scala
new file mode 100644
index 000000000000..17d16162e31f
--- /dev/null
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/GetStatusHandlerE2ESuite.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import scala.jdk.CollectionConverters._
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.connect.proto
+import 
org.apache.spark.connect.proto.GetStatusResponse.OperationStatus.OperationState
+import org.apache.spark.sql.connect.SparkConnectServerTest
+
+/**
+ * End-to-end tests for the GetStatus API. Each test drives a real operation
+ * lifecycle through a Spark Connect client (or a raw blocking stub) and
+ * asserts the operation state reported by GetStatus: RUNNING, SUCCEEDED,
+ * CANCELLED, FAILED, or UNKNOWN for an operation id that never existed.
+ */
+class GetStatusHandlerE2ESuite extends SparkConnectServerTest {
+
+  test("GetStatus tracks operation through lifecycle: RUNNING -> SUCCEEDED") {
+    withClient { client =>
+      // Sleep keeps the query RUNNING long enough to observe that state.
+      val plan = buildPlan("SELECT java_method('java.lang.Thread', 'sleep', 
2000L) as value")
+      val iter = client.execute(plan)
+      val operationId = iter.next().getOperationId
+
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length 
== 1)
+      }
+
+      val runningStatuses =
+        
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+      assert(runningStatuses.size == 1)
+      assert(runningStatuses.head.getState == 
OperationState.OPERATION_STATE_RUNNING)
+
+      // Drain the result iterator so the operation completes.
+      while (iter.hasNext) iter.next()
+      assertEventuallyExecutionReleased(operationId)
+
+      // Use eventually to skip the intermediate TERMINATING status
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        val succeededStatuses =
+          
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+        assert(succeededStatuses.size == 1)
+        assert(succeededStatuses.head.getState == 
OperationState.OPERATION_STATE_SUCCEEDED)
+      }
+    }
+  }
+
+  test("GetStatus tracks operation through lifecycle: RUNNING -> CANCELLED") {
+    withClient { client =>
+      // Sleep keeps the query RUNNING long enough to interrupt it.
+      val plan = buildPlan("SELECT java_method('java.lang.Thread', 'sleep', 
3000L) as value")
+      val iter = client.execute(plan)
+      val operationId = iter.next().getOperationId
+
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length 
== 1)
+      }
+
+      val runningStatuses =
+        
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+      assert(runningStatuses.size == 1)
+      assert(runningStatuses.head.getState == 
OperationState.OPERATION_STATE_RUNNING)
+
+      client.interruptOperation(operationId)
+
+      // Consuming the interrupted operation's results surfaces the error.
+      intercept[SparkException] {
+        while (iter.hasNext) iter.next()
+      }
+
+      assertEventuallyExecutionReleased(operationId)
+
+      // Use eventually to skip the intermediate TERMINATING status
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        val cancelledStatuses =
+          
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+        assert(cancelledStatuses.size == 1)
+        assert(cancelledStatuses.head.getState == 
OperationState.OPERATION_STATE_CANCELLED)
+      }
+    }
+  }
+
+  test("GetStatus returns FAILED for query with error") {
+    withClient { client =>
+      // Use assert_true with a dynamic condition to trigger a runtime failure
+      val plan = buildPlan("SELECT assert_true(id < 0) FROM range(1)")
+      val iter = client.execute(plan)
+      val operationId = iter.next().getOperationId
+
+      // Wait for the execution thread to finish before consuming the error 
from
+      // the iterator. This prevents a race where consuming the error triggers 
the
+      // client's ReleaseExecute, and removeExecuteHolder interrupts the 
execution
+      // thread while it's still in its finally block, which would override the
+      // termination reason from Failed to Canceled.
+      val holder = eventuallyGetExecutionHolderForOperation(operationId)
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(!holder.isExecuteThreadRunnerAlive())
+      }
+
+      intercept[SparkRuntimeException] {
+        while (iter.hasNext) iter.next()
+      }
+
+      assertEventuallyExecutionReleased(operationId)
+
+      // Use eventually to skip the intermediate TERMINATING status
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        val statuses =
+          
client.getOperationStatuses(Seq(operationId)).getOperationStatusesList.asScala
+        assert(statuses.size == 1)
+        assert(statuses.head.getOperationId == operationId)
+        assert(statuses.head.getState == OperationState.OPERATION_STATE_FAILED)
+      }
+    }
+  }
+
+  test("GetStatus returns UNKNOWN for non-existent operation") {
+    withClient { client =>
+      // Execute a simple query first to establish the session on the server
+      val plan = buildPlan("SELECT 1")
+      val iter = client.execute(plan)
+      while (iter.hasNext) iter.next()
+
+      // Query for a random operation ID that was never created
+      val nonExistentOperationId = UUID.randomUUID().toString
+      val statuses =
+        
client.getOperationStatuses(Seq(nonExistentOperationId)).getOperationStatusesList.asScala
+      assert(statuses.size == 1)
+      assert(statuses.head.getOperationId == nonExistentOperationId)
+      assert(statuses.head.getState == OperationState.OPERATION_STATE_UNKNOWN)
+    }
+  }
+
+  test("GetStatus returns all operation statuses when no IDs specified") {
+    withClient { client =>
+      // First operation: completes successfully.
+      val plan1 = buildPlan("SELECT 1 as first_value")
+      val iter1 = client.execute(plan1)
+      val operationId1 = iter1.next().getOperationId
+      while (iter1.hasNext) iter1.next()
+      assertEventuallyExecutionReleased(operationId1)
+
+      // Second operation: fails at runtime.
+      val plan2 = buildPlan("SELECT assert_true(id < 0) FROM range(1)")
+      val iter2 = client.execute(plan2)
+      val operationId2 = iter2.next().getOperationId
+
+      // Wait for the execution thread to finish before consuming the error.
+      // Same race condition as in the FAILED test above.
+      val holder2 = eventuallyGetExecutionHolderForOperation(operationId2)
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(!holder2.isExecuteThreadRunnerAlive())
+      }
+
+      intercept[SparkRuntimeException] {
+        while (iter2.hasNext) iter2.next()
+      }
+
+      assertEventuallyExecutionReleased(operationId2)
+
+      // Use eventually to skip the intermediate TERMINATING status
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        // No operation ids passed: the response should cover both operations.
+        val statuses = 
client.getOperationStatuses().getOperationStatusesList.asScala
+        val operationIdToStatus = statuses.map(s => s.getOperationId -> 
s.getState).toMap
+
+        assert(statuses.size == 2)
+        assert(operationIdToStatus(operationId1) == 
OperationState.OPERATION_STATE_SUCCEEDED)
+        assert(operationIdToStatus(operationId2) == 
OperationState.OPERATION_STATE_FAILED)
+      }
+    }
+  }
+
+  test("GetStatus returns session info even when no operation status is 
requested") {
+    // This test needs raw stub to send a request without operation_status 
field
+    withRawBlockingStub { stub =>
+      // Execute a simple query to establish the session on the server.
+      val plan = buildPlan("SELECT 1")
+      val executeRequest = buildExecutePlanRequest(plan)
+      val sessionId = executeRequest.getSessionId
+      val operationId = executeRequest.getOperationId
+
+      val iter = stub.executePlan(executeRequest)
+      while (iter.hasNext) iter.next()
+
+      // Explicitly release the execution so no operation remains tracked.
+      val releaseRequest = proto.ReleaseExecuteRequest
+        .newBuilder()
+        .setUserContext(userContext)
+        .setSessionId(sessionId)
+        .setOperationId(operationId)
+        
.setReleaseAll(proto.ReleaseExecuteRequest.ReleaseAll.newBuilder().build())
+        .build()
+      stub.releaseExecute(releaseRequest)
+      assertEventuallyExecutionReleased(operationId)
+
+      // GetStatusRequest without any operation_status field set.
+      val statusRequest = proto.GetStatusRequest
+        .newBuilder()
+        .setUserContext(userContext)
+        .setSessionId(sessionId)
+        .build()
+
+      val response = stub.getStatus(statusRequest)
+
+      // Session identifiers are returned even with no statuses requested.
+      assert(response.getSessionId == sessionId)
+      assert(response.getServerSideSessionId.nonEmpty)
+      assert(response.getOperationStatusesList.isEmpty)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to